author    Reynold Xin <rxin@databricks.com>  2015-07-25 01:37:41 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-25 01:37:41 -0700
commit    215713e19924dff69d226a97f1860a5470464d15 (patch)
tree      c9ab7c622d1af96b07346cf0dd3fd058c59f057b
parent    f0ebab3f6d3a9231474acf20110db72c0fb51882 (diff)
[SPARK-9334][SQL] Remove UnsafeRowConverter in favor of UnsafeProjection.
The two are redundant. Once this patch is merged, I plan to remove the inbound conversions from unsafe aggregates.

Author: Reynold Xin <rxin@databricks.com>

Closes #7658 from rxin/unsafeconverters and squashes the following commits:

ed19e6c [Reynold Xin] Updated support types.
2a56d7e [Reynold Xin] [SPARK-9334][SQL] Remove UnsafeRowConverter in favor of UnsafeProjection.
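For context, after this patch a caller needs only a single projection to turn an InternalRow into an UnsafeRow. The following is a minimal sketch (not part of the patch) of the before/after usage, assuming a simple two-column schema; the helper name toUnsafe is illustrative only.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    // Illustrative schema; any schema whose types pass GenerateUnsafeProjection.canSupport works.
    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))

    // Old path (removed here): new UnsafeRowConverter(schema), getSizeRequirement(row),
    // allocate a scratch byte[], then writeRow(row, buffer, offset, size).
    // New path: one generated projection handles sizing, allocation, and writing.
    val projection: UnsafeProjection = UnsafeProjection.create(schema)
    def toUnsafe(row: InternalRow): UnsafeRow = projection.apply(row)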
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java  55
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java  83
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala  276
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  131
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala  79
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala  17
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java  38
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java  15
10 files changed, 262 insertions, 438 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 2f7e84a7f5..684de6e81d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -47,7 +47,7 @@ public final class UnsafeFixedWidthAggregationMap {
/**
* Encodes grouping keys as UnsafeRows.
*/
- private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+ private final UnsafeProjection groupingKeyProjection;
/**
* A hashmap which maps from opaque bytearray keys to bytearray values.
@@ -59,14 +59,6 @@ public final class UnsafeFixedWidthAggregationMap {
*/
private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
- /**
- * Scratch space that is used when encoding grouping keys into UnsafeRow format.
- *
- * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are
- * encountered.
- */
- private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8];
-
private final boolean enablePerfMetrics;
/**
@@ -112,26 +104,17 @@ public final class UnsafeFixedWidthAggregationMap {
TaskMemoryManager memoryManager,
int initialCapacity,
boolean enablePerfMetrics) {
- this.emptyAggregationBuffer =
- convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
this.aggregationBufferSchema = aggregationBufferSchema;
- this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
+ this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
- }
- /**
- * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
- */
- private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
- final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
- final int size = converter.getSizeRequirement(javaRow);
- final byte[] unsafeRow = new byte[size];
- final int writtenLength =
- converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size);
- assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
- return unsafeRow;
+ // Initialize the buffer for aggregation value
+ final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
+ this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+ assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+ UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
}
/**
@@ -139,30 +122,20 @@ public final class UnsafeFixedWidthAggregationMap {
* return the same object.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
- final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
- // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
- if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
- groupingKeyConversionScratchSpace = new byte[groupingKeySize];
- }
- final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
- groupingKey,
- groupingKeyConversionScratchSpace,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- groupingKeySize);
- assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
+ final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
- groupingKeyConversionScratchSpace,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- groupingKeySize);
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes());
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
loc.putNewKey(
- groupingKeyConversionScratchSpace,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- groupingKeySize,
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes(),
emptyAggregationBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length
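For reference, a worked sizing example for the constructor assertion above, assuming a fixed-width aggregation buffer with three columns (the bit-set formula mirrors UnsafeRow.calculateBitSetWidthInBytes):

    // Sketch only: 3 fixed-width fields.
    val numFields = 3
    val bitSetWidthInBytes = ((numFields + 63) / 64) * 8        // one 8-byte word covers up to 64 fields
    val expectedSizeInBytes = numFields * 8 + bitSetWidthInBytes // 24 + 8 = 32 bytes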
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
new file mode 100644
index 0000000000..87521d1f23
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A set of helper methods to write data into {@link UnsafeRow}s,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriters {
+
+ /** Writer for UTF8String. */
+ public static class UTF8StringWriter {
+
+ public static int getSize(UTF8String input) {
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes());
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) {
+ final long offset = target.getBaseOffset() + cursor;
+ final int numBytes = input.numBytes();
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the string to the variable length portion.
+ input.writeToMemory(target.getBaseObject(), offset);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
+
+ /** Writer for binary (byte array) type. */
+ public static class BinaryWriter {
+
+ public static int getSize(byte[] input) {
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) {
+ final long offset = target.getBaseOffset() + cursor;
+ final int numBytes = input.length;
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes to the variable length portion.
+ ByteArray.writeToMemory(input, target.getBaseObject(), offset);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
+
+}
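Both writers above pack the payload's offset and length into the field's fixed-width slot. A small sketch (variable names illustrative) of how that packed long is laid out and decoded:

    // Encoding used by write(...) above: upper 32 bits = cursor, lower 32 bits = byte length.
    val cursor = 24                                      // offset of the payload from the row start
    val numBytes = 5                                     // e.g. "Hello".getBytes.length
    val fixedSlot: Long = (cursor.toLong << 32) | numBytes.toLong
    val decodedCursor: Int = (fixedSlot >>> 32).toInt    // 24
    val decodedLength: Int = fixedSlot.toInt             // 5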
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 6023a2c564..fb873e7e99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -90,8 +90,10 @@ object UnsafeProjection {
* Seq[Expression].
*/
def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
- def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)
+ private def canSupport(types: Array[DataType]): Boolean = {
+ types.forall(GenerateUnsafeProjection.canSupport)
+ }
/**
* Returns an UnsafeProjection for given StructType.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
deleted file mode 100644
index c47b16c0f8..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ /dev/null
@@ -1,276 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.util.Try
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.UTF8String
-
-
-/**
- * Converts Rows into UnsafeRow format. This class is NOT thread-safe.
- *
- * @param fieldTypes the data types of the row's columns.
- */
-class UnsafeRowConverter(fieldTypes: Array[DataType]) {
-
- def this(schema: StructType) {
- this(schema.fields.map(_.dataType))
- }
-
- /** Re-used pointer to the unsafe row being written */
- private[this] val unsafeRow = new UnsafeRow()
-
- /** Functions for encoding each column */
- private[this] val writers: Array[UnsafeColumnWriter] = {
- fieldTypes.map(t => UnsafeColumnWriter.forType(t))
- }
-
- /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
- private[this] val fixedLengthSize: Int =
- (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
-
- /**
- * Compute the amount of space, in bytes, required to encode the given row.
- */
- def getSizeRequirement(row: InternalRow): Int = {
- var fieldNumber = 0
- var variableLengthFieldSize: Int = 0
- while (fieldNumber < writers.length) {
- if (!row.isNullAt(fieldNumber)) {
- variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber)
- }
- fieldNumber += 1
- }
- fixedLengthSize + variableLengthFieldSize
- }
-
- /**
- * Convert the given row into UnsafeRow format.
- *
- * @param row the row to convert
- * @param baseObject the base object of the destination address
- * @param baseOffset the base offset of the destination address
- * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
- * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
- */
- def writeRow(
- row: InternalRow,
- baseObject: Object,
- baseOffset: Long,
- rowLengthInBytes: Int): Int = {
- unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes)
-
- if (writers.length > 0) {
- // zero-out the bitset
- var n = writers.length / 64
- while (n >= 0) {
- PlatformDependent.UNSAFE.putLong(
- unsafeRow.getBaseObject,
- unsafeRow.getBaseOffset + n * 8,
- 0L)
- n -= 1
- }
- }
-
- var fieldNumber = 0
- var appendCursor: Int = fixedLengthSize
- while (fieldNumber < writers.length) {
- if (row.isNullAt(fieldNumber)) {
- unsafeRow.setNullAt(fieldNumber)
- } else {
- appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
- }
- fieldNumber += 1
- }
- appendCursor
- }
-
-}
-
-/**
- * Function for writing a column into an UnsafeRow.
- */
-private abstract class UnsafeColumnWriter {
- /**
- * Write a value into an UnsafeRow.
- *
- * @param source the row being converted
- * @param target a pointer to the converted unsafe row
- * @param column the column to write
- * @param appendCursor the offset from the start of the unsafe row to the end of the row;
- * used for calculating where variable-length data should be written
- * @return the number of variable-length bytes written
- */
- def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int
-
- /**
- * Return the number of bytes that are needed to write this variable-length value.
- */
- def getSize(source: InternalRow, column: Int): Int
-}
-
-private object UnsafeColumnWriter {
-
- def forType(dataType: DataType): UnsafeColumnWriter = {
- dataType match {
- case NullType => NullUnsafeColumnWriter
- case BooleanType => BooleanUnsafeColumnWriter
- case ByteType => ByteUnsafeColumnWriter
- case ShortType => ShortUnsafeColumnWriter
- case IntegerType | DateType => IntUnsafeColumnWriter
- case LongType | TimestampType => LongUnsafeColumnWriter
- case FloatType => FloatUnsafeColumnWriter
- case DoubleType => DoubleUnsafeColumnWriter
- case StringType => StringUnsafeColumnWriter
- case BinaryType => BinaryUnsafeColumnWriter
- case t =>
- throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
- }
- }
-
- /**
- * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool).
- */
- def canEmbed(dataType: DataType): Boolean = Try(forType(dataType)).isSuccess
-}
-
-// ------------------------------------------------------------------------------------------------
-
-private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
- // Primitives don't write to the variable-length region:
- def getSize(sourceRow: InternalRow, column: Int): Int = 0
-}
-
-private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setNullAt(column)
- 0
- }
-}
-
-private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setBoolean(column, source.getBoolean(column))
- 0
- }
-}
-
-private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setByte(column, source.getByte(column))
- 0
- }
-}
-
-private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setShort(column, source.getShort(column))
- 0
- }
-}
-
-private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setInt(column, source.getInt(column))
- 0
- }
-}
-
-private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setLong(column, source.getLong(column))
- 0
- }
-}
-
-private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setFloat(column, source.getFloat(column))
- 0
- }
-}
-
-private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- target.setDouble(column, source.getDouble(column))
- 0
- }
-}
-
-private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
-
- protected[this] def isString: Boolean
- protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte]
-
- override def getSize(source: InternalRow, column: Int): Int = {
- val numBytes = getBytes(source, column).length
- ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
- }
-
- override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- val bytes = getBytes(source, column)
- write(target, bytes, column, cursor)
- }
-
- def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = {
- val offset = target.getBaseOffset + cursor
- val numBytes = bytes.length
- if ((numBytes & 0x07) > 0) {
- // zero-out the padding bytes
- PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
- }
- PlatformDependent.copyMemory(
- bytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- target.getBaseObject,
- offset,
- numBytes
- )
- target.setLong(column, (cursor.toLong << 32) | numBytes.toLong)
- ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
- }
-}
-
-private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter {
- protected[this] def isString: Boolean = true
- def getBytes(source: InternalRow, column: Int): Array[Byte] = {
- source.getAs[UTF8String](column).getBytes
- }
- // TODO(davies): refactor this
- // specialized for codegen
- def getSize(value: UTF8String): Int =
- ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes())
- def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = {
- write(target, value.getBytes, column, cursor)
- }
-}
-
-private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter {
- protected[this] override def isString: Boolean = false
- override def getBytes(source: InternalRow, column: Int): Array[Byte] = {
- source.getAs[Array[Byte]](column)
- }
- // specialized for codegen
- def getSize(value: Array[Byte]): Int =
- ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length)
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0320bcb827..afd0d9cfa1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
-
+import org.apache.spark.sql.types._
/**
* Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -32,25 +31,43 @@ import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
*/
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
- protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
- in.map(ExpressionCanonicalizer.execute)
+ private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
+ private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
- protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ /** Returns true iff we support this data type. */
+ def canSupport(dataType: DataType): Boolean = dataType match {
+ case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+ case NullType => true
+ case _ => false
+ }
+
+ /**
+ * Generates the code to create an [[UnsafeRow]] object based on the input expressions.
+ * @param ctx context for code generation
+ * @param ev specifies the name of the variable for the output [[UnsafeRow]] object
+ * @param expressions input expressions
+ * @return generated code to put the expression output into an [[UnsafeRow]]
+ */
+ def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression])
+ : String = {
+
+ val ret = ev.primitive
+ ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
+ val bufferTerm = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
+ val cursorTerm = ctx.freshName("cursor")
+ val numBytesTerm = ctx.freshName("numBytes")
- protected def create(expressions: Seq[Expression]): UnsafeProjection = {
- val ctx = newCodeGenContext()
val exprs = expressions.map(_.gen(ctx))
val allExprs = exprs.map(_.code).mkString("\n")
val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
- val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter"
- val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter"
+
val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
e.dataType match {
case StringType =>
- s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))"
+ s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
case BinaryType =>
- s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))"
+ s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
case _ => ""
}
}.mkString("")
@@ -58,63 +75,85 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val writers = expressions.zipWithIndex.map { case (e, i) =>
val update = e.dataType match {
case dt if ctx.isPrimitiveType(dt) =>
- s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}"
+ s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
case StringType =>
- s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+ s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case BinaryType =>
- s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+ s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
}
s"""if (${exprs(i).isNull}) {
- target.setNullAt($i);
+ $ret.setNullAt($i);
} else {
$update;
}"""
}.mkString("\n ")
- val code = s"""
- private $exprType[] expressions;
+ s"""
+ $allExprs
+ int $numBytesTerm = $fixedSize $additionalSize;
+ if ($numBytesTerm > $bufferTerm.length) {
+ $bufferTerm = new byte[$numBytesTerm];
+ }
- public Object generate($exprType[] expr) {
- this.expressions = expr;
- return new SpecificProjection();
- }
+ $ret.pointTo(
+ $bufferTerm,
+ org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+ ${expressions.size},
+ $numBytesTerm);
+ int $cursorTerm = $fixedSize;
- class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
- private UnsafeRow target = new UnsafeRow();
- private byte[] buffer = new byte[64];
- ${declareMutableStates(ctx)}
+ $writers
+ boolean ${ev.isNull} = false;
+ """
+ }
- public SpecificProjection() {
- ${initMutableStates(ctx)}
- }
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer.execute)
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(expressions: Seq[Expression]): UnsafeProjection = {
+ val ctx = newCodeGenContext()
+
+ val isNull = ctx.freshName("retIsNull")
+ val primitive = ctx.freshName("retValue")
+ val eval = GeneratedExpressionCode("", isNull, primitive)
+ eval.code = createCode(ctx, eval, expressions)
- // Scala.Function1 need this
- public Object apply(Object row) {
- return apply((InternalRow) row);
+ val code = s"""
+ private $exprType[] expressions;
+
+ public Object generate($exprType[] expr) {
+ this.expressions = expr;
+ return new SpecificProjection();
}
- public UnsafeRow apply(InternalRow i) {
- $allExprs
+ class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+
+ ${declareMutableStates(ctx)}
+
+ public SpecificProjection() {
+ ${initMutableStates(ctx)}
+ }
+
+ // Scala.Function1 needs this
+ public Object apply(Object row) {
+ return apply((InternalRow) row);
+ }
- // additionalSize had '+' in the beginning
- int numBytes = $fixedSize $additionalSize;
- if (numBytes > buffer.length) {
- buffer = new byte[numBytes];
+ public UnsafeRow apply(InternalRow i) {
+ ${eval.code}
+ return ${eval.primitive};
}
- target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
- ${expressions.size}, numBytes);
- int cursor = $fixedSize;
- $writers
- return target;
}
- }
- """
+ """
- logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+ logDebug(s"code for ${expressions.mkString(",")}:\n$code")
val c = compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
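A small sketch of the type gate introduced above (the example types are illustrative): canSupport accepts atomic types other than DecimalType, plus NullType, and rejects everything else; UnsafeProjection.canSupport now delegates to it.

    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
    import org.apache.spark.sql.types._

    val supported = Seq(LongType, StringType, BinaryType, DateType, TimestampType, NullType)
      .forall(GenerateUnsafeProjection.canSupport)   // true for this list
    // DecimalType and complex types (arrays, maps, structs) still take the non-unsafe path.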
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 6e17ffcda9..4930219aa6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -43,7 +43,7 @@ trait ExpressionEvalHelper {
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
- if (UnsafeColumnWriter.canEmbed(expression.dataType)) {
+ if (GenerateUnsafeProjection.canSupport(expression.dataType)) {
checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
}
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index a5d9806c20..4606bcb573 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
@@ -34,22 +33,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
- val converter = new UnsafeRowConverter(fieldTypes)
+ val converter = UnsafeProjection.create(fieldTypes)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setLong(1, 1)
row.setInt(2, 2)
- val sizeRequired: Int = converter.getSizeRequirement(row)
- assert(sizeRequired === 8 + (3 * 8))
- val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten =
- converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
- assert(numBytesWritten === sizeRequired)
-
- val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
@@ -73,25 +65,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
test("basic conversion with primitive, string and binary types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
- val converter = new UnsafeRowConverter(fieldTypes)
+ val converter = UnsafeProjection.create(fieldTypes)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.update(1, UTF8String.fromString("Hello"))
row.update(2, "World".getBytes)
- val sizeRequired: Int = converter.getSizeRequirement(row)
- assert(sizeRequired === 8 + (8 * 3) +
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
- val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(
- row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
- assert(numBytesWritten === sizeRequired)
-
- val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(
- buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
assert(unsafeRow.getBinary(2) === "World".getBytes)
@@ -99,7 +84,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
test("basic conversion with primitive, string, date and timestamp types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
- val converter = new UnsafeRowConverter(fieldTypes)
+ val converter = UnsafeProjection.create(fieldTypes)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
@@ -107,17 +92,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
- val sizeRequired: Int = converter.getSizeRequirement(row)
- assert(sizeRequired === 8 + (8 * 4) +
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
- val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten =
- converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired)
- assert(numBytesWritten === sizeRequired)
-
- val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(
- buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
+
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
// Date is represented as Int in unsafeRow
@@ -148,26 +126,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// DecimalType.Default,
// ArrayType(IntegerType)
)
- val converter = new UnsafeRowConverter(fieldTypes)
+ val converter = UnsafeProjection.create(fieldTypes)
val rowWithAllNullColumns: InternalRow = {
val r = new SpecificMutableRow(fieldTypes)
- for (i <- 0 to fieldTypes.length - 1) {
+ for (i <- fieldTypes.indices) {
r.setNullAt(i)
}
r
}
- val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
- val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(
- rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
- sizeRequired)
- assert(numBytesWritten === sizeRequired)
+ val createdFromNull: UnsafeRow = converter.apply(rowWithAllNullColumns)
- val createdFromNull = new UnsafeRow()
- createdFromNull.pointTo(
- createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired)
for (i <- fieldTypes.indices) {
assert(createdFromNull.isNullAt(i))
}
@@ -202,15 +172,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// r.update(11, Array(11))
r
}
- val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
- converter.writeRow(
- rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
- sizeRequired)
- val setToNullAfterCreation = new UnsafeRow()
- setToNullAfterCreation.pointTo(
- setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
- sizeRequired)
+ val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
@@ -228,8 +191,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
setToNullAfterCreation.setNullAt(i)
}
// There are some garbage left in the var-length area
- assert(Arrays.equals(createdFromNullBuffer,
- java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8)))
+ assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
setToNullAfterCreation.setNullAt(0)
setToNullAfterCreation.setBoolean(1, false)
@@ -269,12 +231,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
- val converter = new UnsafeRowConverter(fieldTypes)
- val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
- val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
- converter.writeRow(row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length)
- converter.writeRow(row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length)
-
- assert(row1Buffer.toSeq === row2Buffer.toSeq)
+ val converter = UnsafeProjection.create(fieldTypes)
+ assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index a1e1695717..40b47ae18d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -22,29 +22,22 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
class UnsafeRowSerializerSuite extends SparkFunSuite {
private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
- val rowConverter = new UnsafeRowConverter(schema)
- val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
- val byteArray = new Array[Byte](rowSizeInBytes)
- rowConverter.writeRow(
- internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes)
- val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes)
- unsafeRow
+ val converter = UnsafeProjection.create(schema)
+ converter.apply(internalRow)
}
- ignore("toUnsafeRow() test helper method") {
+ test("toUnsafeRow() test helper method") {
// This currently doesn't work because the generic getter throws an exception.
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
- assert(row.getString(0) === unsafeRow.get(0).toString)
+ assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
assert(row.getInt(1) === unsafeRow.getInt(1))
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
new file mode 100644
index 0000000000..69b0e206ce
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class ByteArray {
+
+ /**
+ * Writes the content of a byte array into a memory address, identified by an object and an
+ * offset. The target memory address must already have been allocated, and must have enough
+ * space to hold all the bytes in the input array.
+ */
+ public static void writeToMemory(byte[] src, Object target, long targetOffset) {
+ PlatformDependent.copyMemory(
+ src,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ target,
+ targetOffset,
+ src.length
+ );
+ }
+}
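A minimal usage sketch for the new helper (the destination sizing is illustrative; the writers above round it up to a whole word):

    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.types.ByteArray

    val src = "World".getBytes("UTF-8")    // 5 bytes
    val dest = new Array[Byte](8)          // pre-allocated, word-aligned destination
    ByteArray.writeToMemory(src, dest, PlatformDependent.BYTE_ARRAY_OFFSET)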
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 6d8dcb1cbf..85381cf0ef 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -96,6 +96,21 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
/**
+ * Writes the content of this string into a memory address, identified by an object and an offset.
+ * The target memory address must already have been allocated, and have enough space to hold all the
+ * bytes in this string.
+ */
+ public void writeToMemory(Object target, long targetOffset) {
+ PlatformDependent.copyMemory(
+ base,
+ offset,
+ target,
+ targetOffset,
+ numBytes
+ );
+ }
+
+ /**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
*/