From 215713e19924dff69d226a97f1860a5470464d15 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 25 Jul 2015 01:37:41 -0700
Subject: [SPARK-9334][SQL] Remove UnsafeRowConverter in favor of
 UnsafeProjection.

The two are redundant.

Once this patch is merged, I plan to remove the inbound conversions from unsafe aggregates.

Author: Reynold Xin

Closes #7658 from rxin/unsafeconverters and squashes the following commits:

ed19e6c [Reynold Xin] Updated support types.
2a56d7e [Reynold Xin] [SPARK-9334][SQL] Remove UnsafeRowConverter in favor of UnsafeProjection.
---
 .../UnsafeFixedWidthAggregationMap.java            |  55 ++--
 .../sql/catalyst/expressions/UnsafeRowWriters.java |  83 +++++++
 .../sql/catalyst/expressions/Projection.scala      |   4 +-
 .../catalyst/expressions/UnsafeRowConverter.scala  | 276 ---------------------
 .../codegen/GenerateUnsafeProjection.scala         | 131 ++++++----
 .../expressions/ExpressionEvalHelper.scala         |   2 +-
 .../expressions/UnsafeRowConverterSuite.scala      |  79 ++----
 .../sql/execution/UnsafeRowSerializerSuite.scala   |  17 +-
 .../org/apache/spark/unsafe/types/ByteArray.java   |  38 +++
 .../org/apache/spark/unsafe/types/UTF8String.java  |  15 ++
 10 files changed, 262 insertions(+), 438 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
 create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 2f7e84a7f5..684de6e81d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -47,7 +47,7 @@ public final class UnsafeFixedWidthAggregationMap {
   /**
    * Encodes grouping keys as UnsafeRows.
    */
-  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+  private final UnsafeProjection groupingKeyProjection;
 
   /**
    * A hashmap which maps from opaque bytearray keys to bytearray values.
@@ -59,14 +59,6 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
 
-  /**
-   * Scratch space that is used when encoding grouping keys into UnsafeRow format.
-   *
-   * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are
-   * encountered.
-   */
-  private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8];
-
   private final boolean enablePerfMetrics;
 
   /**
@@ -112,26 +104,17 @@
       TaskMemoryManager memoryManager,
       int initialCapacity,
       boolean enablePerfMetrics) {
-    this.emptyAggregationBuffer =
-      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
     this.aggregationBufferSchema = aggregationBufferSchema;
-    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
+    this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
     this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
     this.enablePerfMetrics = enablePerfMetrics;
-  }
 
-  /**
-   * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
-   */
-  private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
-    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
-    final int size = converter.getSizeRequirement(javaRow);
-    final byte[] unsafeRow = new byte[size];
-    final int writtenLength =
-      converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size);
-    assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
-    return unsafeRow;
+    // Initialize the buffer for aggregation value
+    final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
+    this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+    assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+      UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
   }
 
   /**
@@ -139,30 +122,20 @@ public final class UnsafeFixedWidthAggregationMap {
    * return the same object.
    */
   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
-    final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
-    // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
-    if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
-      groupingKeyConversionScratchSpace = new byte[groupingKeySize];
-    }
-    final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
-      groupingKey,
-      groupingKeyConversionScratchSpace,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      groupingKeySize);
-    assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
+    final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
 
     // Probe our map using the serialized key
    final BytesToBytesMap.Location loc = map.lookup(
-      groupingKeyConversionScratchSpace,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      groupingKeySize);
+      unsafeGroupingKeyRow.getBaseObject(),
+      unsafeGroupingKeyRow.getBaseOffset(),
+      unsafeGroupingKeyRow.getSizeInBytes());
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
       loc.putNewKey(
-        groupingKeyConversionScratchSpace,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        groupingKeySize,
+        unsafeGroupingKeyRow.getBaseObject(),
+        unsafeGroupingKeyRow.getBaseOffset(),
+        unsafeGroupingKeyRow.getSizeInBytes(),
         emptyAggregationBuffer,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         emptyAggregationBuffer.length
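
The assert added to the constructor above encodes the UnsafeRow fixed-width layout: a null bitset rounded up to whole 8-byte words, followed by one 8-byte slot per field. A minimal standalone sketch of that size arithmetic (plain Java; the helper names are illustrative, but the formula mirrors UnsafeRow.calculateBitSetWidthInBytes):

    public class FixedWidthLayout {
      // One null bit per field, rounded up to whole 8-byte words;
      // mirrors UnsafeRow.calculateBitSetWidthInBytes.
      static int bitSetWidthInBytes(int numFields) {
        return ((numFields + 63) / 64) * 8;
      }

      // Fixed-length portion of a row: null bitset plus one 8-byte slot per field.
      // This is exactly what the new assert checks for the empty aggregation buffer.
      static int fixedSizeInBytes(int numFields) {
        return bitSetWidthInBytes(numFields) + numFields * 8;
      }

      public static void main(String[] args) {
        // For a 3-field row: 8 (bitset) + 24 (slots) = 32 bytes.
        System.out.println(fixedSizeInBytes(3)); // 32
      }
    }

For the 3-field rows used in the test suite below, this gives 8 + 24 = 32 bytes, matching the `8 + (3 * 8)` assertions there.
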
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
new file mode 100644
index 0000000000..87521d1f23
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A set of helper methods to write data into {@link UnsafeRow}s,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriters {
+
+  /** Writer for UTF8String. */
+  public static class UTF8StringWriter {
+
+    public static int getSize(UTF8String input) {
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes());
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final int numBytes = input.numBytes();
+
+      // zero-out the padding bytes
+      if ((numBytes & 0x07) > 0) {
+        PlatformDependent.UNSAFE.putLong(
+          target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+      }
+
+      // Write the string to the variable length portion.
+      input.writeToMemory(target.getBaseObject(), offset);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
+  /** Writer for binary (byte array) type. */
+  public static class BinaryWriter {
+
+    public static int getSize(byte[] input) {
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final int numBytes = input.length;
+
+      // zero-out the padding bytes
+      if ((numBytes & 0x07) > 0) {
+        PlatformDependent.UNSAFE.putLong(
+          target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+      }
+
+      // Write the bytes to the variable length portion.
+      ByteArray.writeToMemory(input, target.getBaseObject(), offset);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 6023a2c564..fb873e7e99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -90,8 +90,10 @@ object UnsafeProjection {
    * Seq[Expression].
    */
   def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
-  def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
   def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)
+  private def canSupport(types: Array[DataType]): Boolean = {
+    types.forall(GenerateUnsafeProjection.canSupport)
+  }
 
   /**
    * Returns an UnsafeProjection for given StructType.
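
Both writers above follow the same convention for variable-length fields: the payload is copied at the current cursor and padded to a whole 8-byte word, while the field's fixed 8-byte slot packs the cursor into the upper 32 bits and the byte length into the lower 32 bits. A standalone sketch of that packing (plain Java; `pack` and `roundToWord` are illustrative names, not Spark APIs):

    public class SlotPacking {
      // Round up to the next multiple of 8, as
      // ByteArrayMethods.roundNumberOfBytesToNearestWord does.
      static int roundToWord(int numBytes) {
        int remainder = numBytes & 0x07;
        return numBytes + (remainder == 0 ? 0 : 8 - remainder);
      }

      // The writers store: (((long) cursor) << 32) | ((long) numBytes).
      static long pack(int cursor, int numBytes) {
        return (((long) cursor) << 32) | ((long) numBytes);
      }

      public static void main(String[] args) {
        long slot = pack(32, 5);           // e.g. "Hello" written at cursor 32
        int cursor = (int) (slot >>> 32);  // upper half: 32
        int numBytes = (int) slot;         // lower half: 5
        System.out.println(cursor + ", " + numBytes + ", padded to " + roundToWord(numBytes));
        // prints: 32, 5, padded to 8
      }
    }
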
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
deleted file mode 100644
index c47b16c0f8..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ /dev/null
@@ -1,276 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.util.Try
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.UTF8String
-
-
-/**
- * Converts Rows into UnsafeRow format. This class is NOT thread-safe.
- *
- * @param fieldTypes the data types of the row's columns.
- */
-class UnsafeRowConverter(fieldTypes: Array[DataType]) {
-
-  def this(schema: StructType) {
-    this(schema.fields.map(_.dataType))
-  }
-
-  /** Re-used pointer to the unsafe row being written */
-  private[this] val unsafeRow = new UnsafeRow()
-
-  /** Functions for encoding each column */
-  private[this] val writers: Array[UnsafeColumnWriter] = {
-    fieldTypes.map(t => UnsafeColumnWriter.forType(t))
-  }
-
-  /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
-  private[this] val fixedLengthSize: Int =
-    (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
-
-  /**
-   * Compute the amount of space, in bytes, required to encode the given row.
-   */
-  def getSizeRequirement(row: InternalRow): Int = {
-    var fieldNumber = 0
-    var variableLengthFieldSize: Int = 0
-    while (fieldNumber < writers.length) {
-      if (!row.isNullAt(fieldNumber)) {
-        variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber)
-      }
-      fieldNumber += 1
-    }
-    fixedLengthSize + variableLengthFieldSize
-  }
-
-  /**
-   * Convert the given row into UnsafeRow format.
-   *
-   * @param row the row to convert
-   * @param baseObject the base object of the destination address
-   * @param baseOffset the base offset of the destination address
-   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
-   * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
-   */
-  def writeRow(
-      row: InternalRow,
-      baseObject: Object,
-      baseOffset: Long,
-      rowLengthInBytes: Int): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes)
-
-    if (writers.length > 0) {
-      // zero-out the bitset
-      var n = writers.length / 64
-      while (n >= 0) {
-        PlatformDependent.UNSAFE.putLong(
-          unsafeRow.getBaseObject,
-          unsafeRow.getBaseOffset + n * 8,
-          0L)
-        n -= 1
-      }
-    }
-
-    var fieldNumber = 0
-    var appendCursor: Int = fixedLengthSize
-    while (fieldNumber < writers.length) {
-      if (row.isNullAt(fieldNumber)) {
-        unsafeRow.setNullAt(fieldNumber)
-      } else {
-        appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
-      }
-      fieldNumber += 1
-    }
-    appendCursor
-  }
-
-}
-
-/**
- * Function for writing a column into an UnsafeRow.
- */
-private abstract class UnsafeColumnWriter {
-  /**
-   * Write a value into an UnsafeRow.
-   *
-   * @param source the row being converted
-   * @param target a pointer to the converted unsafe row
-   * @param column the column to write
-   * @param appendCursor the offset from the start of the unsafe row to the end of the row;
-   *                     used for calculating where variable-length data should be written
-   * @return the number of variable-length bytes written
-   */
-  def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int
-
-  /**
-   * Return the number of bytes that are needed to write this variable-length value.
-   */
-  def getSize(source: InternalRow, column: Int): Int
-}
-
-private object UnsafeColumnWriter {
-
-  def forType(dataType: DataType): UnsafeColumnWriter = {
-    dataType match {
-      case NullType => NullUnsafeColumnWriter
-      case BooleanType => BooleanUnsafeColumnWriter
-      case ByteType => ByteUnsafeColumnWriter
-      case ShortType => ShortUnsafeColumnWriter
-      case IntegerType | DateType => IntUnsafeColumnWriter
-      case LongType | TimestampType => LongUnsafeColumnWriter
-      case FloatType => FloatUnsafeColumnWriter
-      case DoubleType => DoubleUnsafeColumnWriter
-      case StringType => StringUnsafeColumnWriter
-      case BinaryType => BinaryUnsafeColumnWriter
-      case t =>
-        throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
-    }
-  }
-
-  /**
-   * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool).
-   */
-  def canEmbed(dataType: DataType): Boolean = Try(forType(dataType)).isSuccess
-}
-
-// ------------------------------------------------------------------------------------------------
-
-private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
-  // Primitives don't write to the variable-length region:
-  def getSize(sourceRow: InternalRow, column: Int): Int = 0
-}
-
-private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setNullAt(column)
-    0
-  }
-}
-
-private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setBoolean(column, source.getBoolean(column))
-    0
-  }
-}
-
-private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setByte(column, source.getByte(column))
-    0
-  }
-}
-
-private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setShort(column, source.getShort(column))
-    0
-  }
-}
-
-private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setInt(column, source.getInt(column))
-    0
-  }
-}
-
-private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setLong(column, source.getLong(column))
-    0
-  }
-}
-
-private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setFloat(column, source.getFloat(column))
-    0
-  }
-}
-
-private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    target.setDouble(column, source.getDouble(column))
-    0
-  }
-}
-
-private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
-
-  protected[this] def isString: Boolean
-  protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte]
-
-  override def getSize(source: InternalRow, column: Int): Int = {
-    val numBytes = getBytes(source, column).length
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
-  }
-
-  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    val bytes = getBytes(source, column)
-    write(target, bytes, column, cursor)
-  }
-
-  def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = {
-    val offset = target.getBaseOffset + cursor
-    val numBytes = bytes.length
-    if ((numBytes & 0x07) > 0) {
-      // zero-out the padding bytes
-      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
-    }
-    PlatformDependent.copyMemory(
-      bytes,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      target.getBaseObject,
-      offset,
-      numBytes
-    )
-    target.setLong(column, (cursor.toLong << 32) | numBytes.toLong)
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
-  }
-}
-
-private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter {
-  protected[this] def isString: Boolean = true
-  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
-    source.getAs[UTF8String](column).getBytes
-  }
-  // TODO(davies): refactor this
-  // specialized for codegen
-  def getSize(value: UTF8String): Int =
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes())
-  def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = {
-    write(target, value.getBytes, column, cursor)
-  }
-}
-
-private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter {
-  protected[this] override def isString: Boolean = false
-  override def getBytes(source: InternalRow, column: Int): Array[Byte] = {
-    source.getAs[Array[Byte]](column)
-  }
-  // specialized for codegen
-  def getSize(value: Array[Byte]): Int =
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length)
-}
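
The deleted converter zeroed the row's null bitset one 8-byte word at a time before writing any fields, and setNullAt then flipped a single bit per null column. A standalone sketch of the underlying bit arithmetic (plain Java; an illustrative helper class, not a Spark API):

    public class NullBitSet {
      // One long per 64 fields, matching the words the deleted writeRow zeroed.
      private final long[] words;

      NullBitSet(int numFields) {
        words = new long[(numFields + 63) / 64];
      }

      // What setNullAt(i) boils down to: set bit i in word i / 64.
      void setNull(int i) {
        words[i >> 6] |= 1L << (i & 63);
      }

      boolean isNull(int i) {
        return (words[i >> 6] & (1L << (i & 63))) != 0;
      }

      public static void main(String[] args) {
        NullBitSet bits = new NullBitSet(70); // 70 fields -> two 8-byte words
        bits.setNull(65);
        System.out.println(bits.isNull(65) + " " + bits.isNull(2)); // true false
      }
    }
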
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0320bcb827..afd0d9cfa1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
-
+import org.apache.spark.sql.types._
 
 /**
  * Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -32,25 +31,43 @@ import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
  */
 object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
 
-  protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
-    in.map(ExpressionCanonicalizer.execute)
+  private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
+  private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
 
-  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+  /** Returns true iff we support this data type. */
+  def canSupport(dataType: DataType): Boolean = dataType match {
+    case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+    case NullType => true
+    case _ => false
+  }
+
+  /**
+   * Generates the code to create an [[UnsafeRow]] object based on the input expressions.
+   * @param ctx context for code generation
+   * @param ev specifies the name of the variable for the output [[UnsafeRow]] object
+   * @param expressions input expressions
+   * @return generated code to put the expression output into an [[UnsafeRow]]
+   */
+  def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression])
+    : String = {
+
+    val ret = ev.primitive
+    ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
+    val bufferTerm = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
+    val cursorTerm = ctx.freshName("cursor")
+    val numBytesTerm = ctx.freshName("numBytes")
 
-  protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    val ctx = newCodeGenContext()
     val exprs = expressions.map(_.gen(ctx))
     val allExprs = exprs.map(_.code).mkString("\n")
     val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
-    val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter"
-    val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter"
+
     val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
       e.dataType match {
         case StringType =>
-          s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))"
+          s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
         case BinaryType =>
-          s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))"
+          s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
         case _ => ""
       }
     }.mkString("")
 
@@ -58,63 +75,85 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val writers = expressions.zipWithIndex.map { case (e, i) =>
       val update = e.dataType match {
         case dt if ctx.isPrimitiveType(dt) =>
-          s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}"
+          s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
         case StringType =>
-          s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+          s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case BinaryType =>
-          s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+          s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case NullType => ""
         case _ =>
          throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
       }
       s"""if (${exprs(i).isNull}) {
-          target.setNullAt($i);
+          $ret.setNullAt($i);
         } else {
           $update;
         }"""
     }.mkString("\n          ")
 
-    val code = s"""
-    private $exprType[] expressions;
+    s"""
+      $allExprs
+      int $numBytesTerm = $fixedSize $additionalSize;
+      if ($numBytesTerm > $bufferTerm.length) {
+        $bufferTerm = new byte[$numBytesTerm];
+      }
 
-    public Object generate($exprType[] expr) {
-      this.expressions = expr;
-      return new SpecificProjection();
-    }
+      $ret.pointTo(
+        $bufferTerm,
+        org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+        ${expressions.size},
+        $numBytesTerm);
+      int $cursorTerm = $fixedSize;
 
-    class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
-      private UnsafeRow target = new UnsafeRow();
-      private byte[] buffer = new byte[64];
-      ${declareMutableStates(ctx)}
+      $writers
+      boolean ${ev.isNull} = false;
+    """
+  }
 
-      public SpecificProjection() {
-        ${initMutableStates(ctx)}
-      }
+  protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+    in.map(ExpressionCanonicalizer.execute)
+
+  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+    in.map(BindReferences.bindReference(_, inputSchema))
+
+  protected def create(expressions: Seq[Expression]): UnsafeProjection = {
+    val ctx = newCodeGenContext()
+
+    val isNull = ctx.freshName("retIsNull")
+    val primitive = ctx.freshName("retValue")
+    val eval = GeneratedExpressionCode("", isNull, primitive)
+    eval.code = createCode(ctx, eval, expressions)
 
-      // Scala.Function1 need this
-      public Object apply(Object row) {
-        return apply((InternalRow) row);
+    val code = s"""
+      private $exprType[] expressions;
+
+      public Object generate($exprType[] expr) {
+        this.expressions = expr;
+        return new SpecificProjection();
       }
 
-      public UnsafeRow apply(InternalRow i) {
-        $allExprs
+      class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+
+        ${declareMutableStates(ctx)}
+
+        public SpecificProjection() {
+          ${initMutableStates(ctx)}
+        }
+
+        // Scala.Function1 needs this
+        public Object apply(Object row) {
+          return apply((InternalRow) row);
+        }
 
-        // additionalSize had '+' in the beginning
-        int numBytes = $fixedSize $additionalSize;
-        if (numBytes > buffer.length) {
-          buffer = new byte[numBytes];
+        public UnsafeRow apply(InternalRow i) {
+          ${eval.code}
+          return ${eval.primitive};
         }
-        target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
-          ${expressions.size}, numBytes);
-        int cursor = $fixedSize;
-        $writers
-        return target;
       }
-    }
-    """
+      """
 
-    logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+    logDebug(s"code for ${expressions.mkString(",")}:\n$code")
 
     val c = compile(code)
     c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 6e17ffcda9..4930219aa6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -43,7 +43,7 @@ trait ExpressionEvalHelper {
     checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
     checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
-    if (UnsafeColumnWriter.canEmbed(expression.dataType)) {
+    if (GenerateUnsafeProjection.canSupport(expression.dataType)) {
       checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
     }
     checkEvaluationWithOptimization(expression, catalystValue, inputRow)
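
To make the codegen template above concrete: for a two-column (bigint, string) schema, createCode would emit something roughly like the following. This is a hand-expanded sketch, not actual generator output; the variable names stand in for ctx.freshName results, and it assumes the UnsafeRowWriters class added by this patch is on the classpath:

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.UnsafeRowWriters;
    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.unsafe.types.UTF8String;

    // Hand-expanded sketch of a generated projection for (bigint, string).
    class ExpandedProjection {
      private final UnsafeRow ret = new UnsafeRow(); // mutable state: reused output row
      private byte[] buffer = new byte[64];          // mutable state: grows, never shrinks

      public UnsafeRow apply(InternalRow i) {
        // $allExprs: evaluate the two bound input expressions.
        boolean isNull0 = i.isNullAt(0);
        long value0 = isNull0 ? -1L : i.getLong(0);
        boolean isNull1 = i.isNullAt(1);
        UTF8String value1 = isNull1 ? null : i.getUTF8String(1);

        // fixedSize = bitset (8 bytes for up to 64 fields) + 2 slots (16) = 24;
        // $additionalSize adds the word-rounded string size.
        int numBytes = 24 + (isNull1 ? 0 : UnsafeRowWriters.UTF8StringWriter.getSize(value1));
        if (numBytes > buffer.length) {
          buffer = new byte[numBytes];
        }
        ret.pointTo(buffer, PlatformDependent.BYTE_ARRAY_OFFSET, 2, numBytes);
        int cursor = 24;

        // $writers: one null-check-plus-write per column.
        if (isNull0) {
          ret.setNullAt(0);
        } else {
          ret.setLong(0, value0);
        }
        if (isNull1) {
          ret.setNullAt(1);
        } else {
          cursor += UnsafeRowWriters.UTF8StringWriter.write(ret, 1, cursor, value1);
        }
        return ret;
      }
    }
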
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index a5d9806c20..4606bcb573 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -34,22 +33,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
   test("basic conversion with only primitive types") {
     val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
-    val converter = new UnsafeRowConverter(fieldTypes)
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setLong(1, 1)
     row.setInt(2, 2)
 
-    val sizeRequired: Int = converter.getSizeRequirement(row)
-    assert(sizeRequired === 8 + (3 * 8))
-    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten =
-      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
-    assert(numBytesWritten === sizeRequired)
-
-    val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
 
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
@@ -73,25 +65,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
   test("basic conversion with primitive, string and binary types") {
     val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
-    val converter = new UnsafeRowConverter(fieldTypes)
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.update(1, UTF8String.fromString("Hello"))
     row.update(2, "World".getBytes)
 
-    val sizeRequired: Int = converter.getSizeRequirement(row)
-    assert(sizeRequired === 8 + (8 * 3) +
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
-    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(
-      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
-    assert(numBytesWritten === sizeRequired)
-
-    val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(
-      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.getBinary(2) === "World".getBytes)
@@ -99,7 +84,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
   test("basic conversion with primitive, string, date and timestamp types") {
     val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
-    val converter = new UnsafeRowConverter(fieldTypes)
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
@@ -107,17 +92,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
     row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
 
-    val sizeRequired: Int = converter.getSizeRequirement(row)
-    assert(sizeRequired === 8 + (8 * 4) +
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) +
      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
-    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten =
-      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
-    assert(numBytesWritten === sizeRequired)
-
-    val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(
-      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -148,26 +126,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       // DecimalType.Default,
       // ArrayType(IntegerType)
     )
-    val converter = new UnsafeRowConverter(fieldTypes)
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val rowWithAllNullColumns: InternalRow = {
       val r = new SpecificMutableRow(fieldTypes)
-      for (i <- 0 to fieldTypes.length - 1) {
+      for (i <- fieldTypes.indices) {
         r.setNullAt(i)
       }
       r
     }
 
-    val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
-    val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
-      sizeRequired)
-    assert(numBytesWritten === sizeRequired)
+    val createdFromNull: UnsafeRow = converter.apply(rowWithAllNullColumns)
 
-    val createdFromNull = new UnsafeRow()
-    createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
     for (i <- fieldTypes.indices) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -202,15 +172,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       // r.update(11, Array(11))
       r
     }
-    val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
-    converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
-      sizeRequired)
-    val setToNullAfterCreation = new UnsafeRow()
-    setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
-      sizeRequired)
+
+    val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
     assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
@@ -228,8 +191,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       setToNullAfterCreation.setNullAt(i)
     }
     // There are some garbage left in the var-length area
-    assert(Arrays.equals(createdFromNullBuffer,
-      java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8)))
+    assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
 
     setToNullAfterCreation.setNullAt(0)
     setToNullAfterCreation.setBoolean(1, false)
@@ -269,12 +231,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
     row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
 
-    val converter = new UnsafeRowConverter(fieldTypes)
-    val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
-    val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
-    converter.writeRow(row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length)
-    converter.writeRow(row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length)
-
-    assert(row1Buffer.toSeq === row2Buffer.toSeq)
+    val converter = UnsafeProjection.create(fieldTypes)
+    assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
   }
 }
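
The last test above can only pass if the two NaN bit patterns serialize to identical bytes, which implies NaN values are canonicalized somewhere on the write path rather than copied bit-for-bit. The JDK draws the same distinction between raw and canonicalized NaN bits, which a plain-Java check makes visible:

    public class NaNBits {
      public static void main(String[] args) {
        // A NaN with a non-canonical payload, as used by row2 in the test.
        double oddNaN = Double.longBitsToDouble(0x7fffffffffffffffL);
        // Raw bits typically preserve the payload; doubleToLongBits collapses
        // every NaN to the canonical 0x7ff8000000000000L pattern.
        System.out.println(Long.toHexString(Double.doubleToRawLongBits(oddNaN))); // 7fffffffffffffff
        System.out.println(Long.toHexString(Double.doubleToLongBits(oddNaN)));    // 7ff8000000000000
        System.out.println(
          Double.doubleToLongBits(oddNaN) == Double.doubleToLongBits(Double.NaN)); // true
      }
    }
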
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -22,29 +22,22 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
 
 class UnsafeRowSerializerSuite extends SparkFunSuite {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
-    val rowConverter = new UnsafeRowConverter(schema)
-    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
-    val byteArray = new Array[Byte](rowSizeInBytes)
-    rowConverter.writeRow(
-      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes)
-    val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes)
-    unsafeRow
+    val converter = UnsafeProjection.create(schema)
+    converter.apply(internalRow)
   }
 
-  ignore("toUnsafeRow() test helper method") {
+  test("toUnsafeRow() test helper method") {
     // This currently doesnt work because the generic getter throws an exception.
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
-    assert(row.getString(0) === unsafeRow.get(0).toString)
+    assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
     assert(row.getInt(1) === unsafeRow.getInt(1))
   }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
new file mode 100644
index 0000000000..69b0e206ce
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class ByteArray {
+
+  /**
+   * Writes the content of a byte array into a memory address, identified by an object and an
+   * offset. The target memory address must already have been allocated, and have enough space
+   * to hold all the bytes in the array.
+   */
+  public static void writeToMemory(byte[] src, Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      src,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      target,
+      targetOffset,
+      src.length
+    );
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 6d8dcb1cbf..85381cf0ef 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -95,6 +95,21 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     this.numBytes = size;
   }
 
+  /**
+   * Writes the content of this string into a memory address, identified by an object and an
+   * offset. The target memory address must already have been allocated, and have enough space
+   * to hold all the bytes in this string.
+   */
+  public void writeToMemory(Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      base,
+      offset,
+      target,
+      targetOffset,
+      numBytes
+    );
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point
-- 
cgit v1.2.3