author     Davies Liu <davies@databricks.com>    2015-10-12 21:12:59 -0700
committer  Cheng Lian <lian@databricks.com>      2015-10-12 21:12:59 -0700
commit     c4da5345a0ef643a7518756caaa18ff3f3ea9acc (patch)
tree       330ed74a4ebe7e98b8983df84d0d91f556b7199e
parent     f97e9323b526b3d0b0fee0ca03f4276f37bb5750 (diff)
[SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types
This PR improves the unrolling and reading of complex types in the columnar cache:

1) Use UnsafeProjection to serialize complex types, so they are no longer serialized three times (two of which were for actualSize).
2) Copy the bytes from UnsafeRow/UnsafeArrayData into the ByteBuffer directly, avoiding the intermediate byte[].
3) Use the underlying array of the ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copying.

Combined, these optimizations reduce the unrolling time from 25s to 21s (20% less) and the scanning time from 3.5s to 2.5s (28% less).

```
df = sqlContext.read.parquet(path)
t = time.time()
df.cache()
df.count()
print 'unrolling', time.time() - t
for i in range(10):
    t = time.time()
    print df.select("*")._jdf.queryExecution().toRdd().count()
    print time.time() - t
```

The schema is:

```
root
 |-- a: struct (nullable = true)
 |    |-- b: long (nullable = true)
 |    |-- c: string (nullable = true)
 |-- d: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- e: map (nullable = true)
 |    |-- key: long
 |    |-- value: string (valueContainsNull = true)
```

The columnar cache now depends on UnsafeProjection supporting all data types (including UDTs), so this PR also fixes that.

Author: Davies Liu <davies@databricks.com>

Closes #9016 from davies/complex2.
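For readers skimming the diff below, the heart of optimizations 2) and 3) is the new `writeTo(ByteBuffer)` helpers plus the zero-copy `extract` paths in `ColumnType.scala`. Here is a minimal Scala sketch of that round trip for a struct column; the object and method names are illustrative only, but the `writeTo`/`pointTo` calls follow the signatures introduced in this patch, and it assumes a heap-backed (array-backed) ByteBuffer:

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform

object ColumnarStructSketch {
  // Append: write a length prefix, then copy the row's bytes straight into the
  // buffer's backing array; no intermediate byte[] is allocated.
  def appendStruct(row: UnsafeRow, buffer: ByteBuffer): Unit = {
    buffer.putInt(row.getSizeInBytes)
    row.writeTo(buffer) // copies into buffer.array() and advances the position
  }

  // Extract: point a fresh UnsafeRow at the buffer's underlying array instead
  // of copying the bytes out first.
  def extractStruct(buffer: ByteBuffer, numFields: Int): UnsafeRow = {
    val sizeInBytes = buffer.getInt()
    assert(buffer.hasArray)
    val cursor = buffer.position()
    buffer.position(cursor + sizeInBytes)
    val row = new UnsafeRow
    row.pointTo(
      buffer.array(),
      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
      numFields,
      sizeInBytes)
    row
  }
}
```

Because the extracted row aliases the cached buffer rather than owning a copy, downstream consumers must copy any values they retain; this is also why the codegen paths touched by this patch clone UTF8String values when setting columns.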
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java12
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala187
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala37
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala13
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java10
12 files changed, 188 insertions, 140 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index fdd9125613..796f8abec9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
@@ -145,6 +146,8 @@ public class UnsafeArrayData extends ArrayData {
return getArray(ordinal);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof UserDefinedType) {
+ return get(ordinal, ((UserDefinedType)dataType).sqlType());
} else {
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
@@ -306,6 +309,15 @@ public class UnsafeArrayData extends ArrayData {
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
+ public void writeTo(ByteBuffer buffer) {
+ assert(buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public UnsafeArrayData copy() {
UnsafeArrayData arrayCopy = new UnsafeArrayData();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5af7ed5d6e..36859fbab9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import java.io.*;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@@ -326,6 +327,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
return getArray(ordinal);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof UserDefinedType) {
+ return get(ordinal, ((UserDefinedType)dataType).sqlType());
} else {
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
@@ -602,6 +605,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
+ public void writeTo(ByteBuffer buffer) {
+ assert (buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public void writeExternal(ObjectOutput out) throws IOException {
byte[] bytes = getBytes();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a0fe5bd77e..7544d27e3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -129,6 +129,7 @@ class CodeGenContext {
case _: ArrayType => s"$input.getArray($ordinal)"
case _: MapType => s"$input.getMap($ordinal)"
case NullType => "null"
+ case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
case _ => s"($jt)$input.get($ordinal, null)"
}
}
@@ -143,6 +144,7 @@ class CodeGenContext {
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
case StringType => s"$row.update($ordinal, $value.clone())"
+ case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
case _ => s"$row.update($ordinal, $value)"
}
}
@@ -177,6 +179,7 @@ class CodeGenContext {
case _: MapType => "MapData"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+ case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
case _ => "Object"
@@ -222,6 +225,7 @@ class CodeGenContext {
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+ case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case other => s"$c1.equals($c2)"
}
@@ -255,6 +259,7 @@ class CodeGenContext {
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+ case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 9873630937..ee50587ed0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
+ case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => GeneratedExpressionCode("", "false", input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 3e0e81733f..1b957a508d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
+ case dt: OpenHashSetUDT => false // it's not a standard UDT
+ case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
- case ((input, dt), index) =>
+ case ((input, dataType), index) =>
+ val dt = dataType match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }
val tmpCursor = ctx.freshName("tmpCursor")
val setNull = dt match {
@@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val index = ctx.freshName("index")
val element = ctx.freshName("element")
- val jt = ctx.javaType(elementType)
+ val et = elementType match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }
+
+ val jt = ctx.javaType(et)
- val fixedElementSize = elementType match {
+ val fixedElementSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
- case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+ case _ if ctx.isPrimitiveType(jt) => et.defaultSize
case _ => 0
}
- val writeElement = elementType match {
+ val writeElement = et match {
case t: StructType =>
s"""
$arrayWriter.setOffset($index);
@@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
"""
- case _ if ctx.isPrimitiveType(elementType) =>
+ case _ if ctx.isPrimitiveType(et) =>
// Should we do word align?
- val dataSize = elementType.defaultSize
+ val dataSize = et.defaultSize
s"""
$arrayWriter.setOffset($index);
- ${writePrimitiveType(ctx, element, elementType,
+ ${writePrimitiveType(ctx, element, et,
s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
$bufferHolder.cursor += $dataSize;
"""
@@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
if ($input.isNullAt($index)) {
$arrayWriter.setNullAt($index);
} else {
- final $jt $element = ${ctx.getValue(input, elementType, index)};
+ final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 62478667eb..42ec4d3433 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.{ByteBuffer, ByteOrder}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.MutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow}
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.types._
@@ -109,15 +108,15 @@ private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalTy
with NullableColumnAccessor
private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
- extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
+ extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType))
with NullableColumnAccessor
private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
- extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
+ extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType))
with NullableColumnAccessor
private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
- extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
+ extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType))
with NullableColumnAccessor
private[sql] object ColumnAccessor {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 3563eacb3a..2bc2c96b61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.math.{BigDecimal, BigInteger}
-import java.nio.{ByteOrder, ByteBuffer}
+import java.nio.ByteBuffer
import scala.reflect.runtime.universe.TypeTag
@@ -92,7 +92,7 @@ private[sql] sealed abstract class ColumnType[JvmType] {
* boxing/unboxing costs whenever possible.
*/
def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
- to.update(toOrdinal, from.get(fromOrdinal, dataType))
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
/**
@@ -147,6 +147,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
+
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
to.setInt(toOrdinal, from.getInt(fromOrdinal))
}
@@ -324,15 +325,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
}
override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
- val stringBytes = v.getBytes
- buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
+ buffer.putInt(v.numBytes())
+ v.writeTo(buffer)
}
override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
- val stringBytes = new Array[Byte](length)
- buffer.get(stringBytes, 0, length)
- UTF8String.fromBytes(stringBytes)
+ assert(buffer.hasArray)
+ val base = buffer.array()
+ val offset = buffer.arrayOffset()
+ val cursor = buffer.position()
+ buffer.position(cursor + length)
+ UTF8String.fromBytes(base, offset + cursor, length)
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
@@ -386,11 +390,6 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize:
def serialize(value: JvmType): Array[Byte]
def deserialize(bytes: Array[Byte]): JvmType
- override def actualSize(row: InternalRow, ordinal: Int): Int = {
- // TODO: grow the buffer in append(), so serialize() will not be called twice
- serialize(getField(row, ordinal)).length + 4
- }
-
override def append(v: JvmType, buffer: ByteBuffer): Unit = {
val bytes = serialize(v)
buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
@@ -416,6 +415,10 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
row.getBinary(ordinal)
}
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ row.getBinary(ordinal).length + 4
+ }
+
def serialize(value: Array[Byte]): Array[Byte] = value
def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
}
@@ -433,6 +436,10 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int)
row.setDecimal(ordinal, value, precision)
}
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1
+ }
+
override def serialize(value: Decimal): Array[Byte] = {
value.toJavaBigDecimal.unscaledValue().toByteArray
}
@@ -449,124 +456,118 @@ private[sql] object LARGE_DECIMAL {
}
}
-private[sql] case class STRUCT(dataType: StructType)
- extends ByteArrayColumnType[InternalRow](20) {
+private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] {
- private val projection: UnsafeProjection =
- UnsafeProjection.create(dataType)
private val numOfFields: Int = dataType.fields.size
- override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = {
+ override def defaultSize: Int = 20
+
+ override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = {
row.update(ordinal, value)
}
- override def getField(row: InternalRow, ordinal: Int): InternalRow = {
- row.getStruct(ordinal, numOfFields)
+ override def getField(row: InternalRow, ordinal: Int): UnsafeRow = {
+ row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow]
}
- override def serialize(value: InternalRow): Array[Byte] = {
- val unsafeRow = if (value.isInstanceOf[UnsafeRow]) {
- value.asInstanceOf[UnsafeRow]
- } else {
- projection(value)
- }
- unsafeRow.getBytes
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ 4 + getField(row, ordinal).getSizeInBytes
}
- override def deserialize(bytes: Array[Byte]): InternalRow = {
+ override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = {
+ buffer.putInt(value.getSizeInBytes)
+ value.writeTo(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer): UnsafeRow = {
+ val sizeInBytes = buffer.getInt()
+ assert(buffer.hasArray)
+ val base = buffer.array()
+ val offset = buffer.arrayOffset()
+ val cursor = buffer.position()
+ buffer.position(cursor + sizeInBytes)
val unsafeRow = new UnsafeRow
- unsafeRow.pointTo(bytes, numOfFields, bytes.length)
+ unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
unsafeRow
}
- override def clone(v: InternalRow): InternalRow = v.copy()
+ override def clone(v: UnsafeRow): UnsafeRow = v.copy()
}
-private[sql] case class ARRAY(dataType: ArrayType)
- extends ByteArrayColumnType[ArrayData](16) {
+private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] {
- private lazy val projection = UnsafeProjection.create(Array[DataType](dataType))
- private val mutableRow = new GenericMutableRow(new Array[Any](1))
+ override def defaultSize: Int = 16
- override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = {
+ override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = {
row.update(ordinal, value)
}
- override def getField(row: InternalRow, ordinal: Int): ArrayData = {
- row.getArray(ordinal)
+ override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = {
+ row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
}
- override def serialize(value: ArrayData): Array[Byte] = {
- val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) {
- value.asInstanceOf[UnsafeArrayData]
- } else {
- mutableRow(0) = value
- projection(mutableRow).getArray(0)
- }
- val outputBuffer =
- ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder())
- outputBuffer.putInt(unsafeArray.numElements())
- val underlying = outputBuffer.array()
- unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4)
- underlying
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ val unsafeArray = getField(row, ordinal)
+ 4 + 4 + unsafeArray.getSizeInBytes
}
- override def deserialize(bytes: Array[Byte]): ArrayData = {
- val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
- val numElements = buffer.getInt
- val array = new UnsafeArrayData
- array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4)
- array
+ override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
+ buffer.putInt(4 + value.getSizeInBytes)
+ buffer.putInt(value.numElements())
+ value.writeTo(buffer)
}
- override def clone(v: ArrayData): ArrayData = v.copy()
+ override def extract(buffer: ByteBuffer): UnsafeArrayData = {
+ val numBytes = buffer.getInt
+ assert(buffer.hasArray)
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ UnsafeReaders.readArray(
+ buffer.array(),
+ Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+ numBytes)
+ }
+
+ override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
}
-private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) {
+private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] {
- private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType))
- private val mutableRow = new GenericMutableRow(new Array[Any](1))
+ override def defaultSize: Int = 32
- override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = {
+ override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = {
row.update(ordinal, value)
}
- override def getField(row: InternalRow, ordinal: Int): MapData = {
- row.getMap(ordinal)
+ override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = {
+ row.getMap(ordinal).asInstanceOf[UnsafeMapData]
}
- override def serialize(value: MapData): Array[Byte] = {
- val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) {
- value.asInstanceOf[UnsafeMapData]
- } else {
- mutableRow(0) = value
- projection(mutableRow).getMap(0)
- }
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ val unsafeMap = getField(row, ordinal)
+ 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+ }
+
+ override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
+ buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
+ buffer.putInt(value.numElements())
+ buffer.putInt(value.keyArray().getSizeInBytes)
+ value.keyArray().writeTo(buffer)
+ value.valueArray().writeTo(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer): UnsafeMapData = {
+ val numBytes = buffer.getInt
+ assert(buffer.hasArray)
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ UnsafeReaders.readMap(
+ buffer.array(),
+ Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+ numBytes)
+ }
- val outputBuffer =
- ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder())
- outputBuffer.putInt(unsafeMap.numElements())
- val keyBytes = unsafeMap.keyArray().getSizeInBytes
- outputBuffer.putInt(keyBytes)
- val underlying = outputBuffer.array()
- unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8)
- unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes)
- underlying
- }
-
- override def deserialize(bytes: Array[Byte]): MapData = {
- val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
- val numElements = buffer.getInt
- val keyArraySize = buffer.getInt
- val keyArray = new UnsafeArrayData
- val valueArray = new UnsafeArrayData
- keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize)
- valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements,
- bytes.length - 8 - keyArraySize)
- new UnsafeMapData(keyArray, valueArray)
- }
-
- override def clone(v: MapData): MapData = v.copy()
+ override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
}
private[sql] object ColumnType {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d7e145f9c2..d967814f62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
+import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Accumulable, Accumulator, Accumulators}
@@ -38,7 +38,9 @@ private[sql] object InMemoryRelation {
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String]): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+ new InMemoryRelation(child.output, useCompression, batchSize, storageLevel,
+ if (child.outputsUnsafeRows) child else ConvertToUnsafe(child),
+ tableName)()
}
private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index ceb8ad97bb..0e6e1bcf72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar
-import java.nio.ByteBuffer
+import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types._
import org.apache.spark.{Logging, SparkFunSuite}
@@ -55,7 +55,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
assertResult(expected, s"Wrong actualSize for $columnType") {
val row = new GenericMutableRow(1)
row.update(0, CatalystTypeConverters.convertToCatalyst(value))
- columnType.actualSize(row, 0)
+ val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
+ columnType.actualSize(proj(row), 0)
}
}
@@ -99,35 +100,27 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
- val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
- val seq = (0 until 4).map(_ => makeRandomValue(columnType))
+ val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder())
+ val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
+ val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
test(s"$columnType append/extract") {
buffer.rewind()
- seq.foreach(columnType.append(_, buffer))
+ seq.foreach(columnType.append(_, 0, buffer))
buffer.rewind()
- seq.foreach { expected =>
- logInfo("buffer = " + buffer + ", expected = " + expected)
- val extracted = columnType.extract(buffer)
- assert(
- converter(expected) === converter(extracted),
- "Extracted value didn't equal to the original one. " +
- hexDump(expected) + " != " + hexDump(extracted) +
- ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
+ seq.foreach { row =>
+ logInfo("buffer = " + buffer + ", expected = " + row)
+ val expected = converter(row.get(0, columnType.dataType))
+ val extracted = converter(columnType.extract(buffer))
+ assert(expected === extracted,
+ s"Extracted value didn't equal to the original one. $expected != $extracted, buffer =" +
+ dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
}
}
}
- private def hexDump(value: Any): String = {
- if (value == null) {
- ""
- } else {
- value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
- }
- }
-
private def dumpBuffer(buff: ByteBuffer): Any = {
val sb = new StringBuilder()
while (buff.hasRemaining) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 78cebbf3cc..aa1605fee8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
import org.apache.spark.sql.types._
class TestNullableColumnAccessor[JvmType](
@@ -64,10 +64,11 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
test(s"Nullable $typeName column accessor: access null values") {
val builder = TestNullableColumnBuilder(columnType)
val randomRow = makeRandomRow(columnType)
+ val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
(0 until 4).foreach { _ =>
- builder.appendFrom(randomRow, 0)
- builder.appendFrom(nullRow, 0)
+ builder.appendFrom(proj(randomRow), 0)
+ builder.appendFrom(proj(nullRow), 0)
}
val accessor = TestNullableColumnAccessor(builder.build(), columnType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index fba08e626d..9140457783 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
import org.apache.spark.sql.types._
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -51,6 +51,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
columnType: ColumnType[JvmType]): Unit = {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val dataType = columnType.dataType
+ val proj = UnsafeProjection.create(Array[DataType](dataType))
+ val converter = CatalystTypeConverters.createToScalaConverter(dataType)
test(s"$typeName column builder: empty column") {
val columnBuilder = TestNullableColumnBuilder(columnType)
@@ -65,7 +68,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
val randomRow = makeRandomRow(columnType)
(0 until 4).foreach { _ =>
- columnBuilder.appendFrom(randomRow, 0)
+ columnBuilder.appendFrom(proj(randomRow), 0)
}
val buffer = columnBuilder.build()
@@ -77,12 +80,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
val columnBuilder = TestNullableColumnBuilder(columnType)
val randomRow = makeRandomRow(columnType)
val nullRow = makeNullRow(1)
- val dataType = columnType.dataType
- val converter = CatalystTypeConverters.createToScalaConverter(dataType)
(0 until 4).foreach { _ =>
- columnBuilder.appendFrom(randomRow, 0)
- columnBuilder.appendFrom(nullRow, 0)
+ columnBuilder.appendFrom(proj(randomRow), 0)
+ columnBuilder.appendFrom(proj(nullRow), 0)
}
val buffer = columnBuilder.build()
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 216aeea60d..b7aecb5102 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.types;
import javax.annotation.Nonnull;
import java.io.*;
+import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Map;
@@ -137,6 +138,15 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable
Platform.copyMemory(base, offset, target, targetOffset, numBytes);
}
+ public void writeTo(ByteBuffer buffer) {
+ assert(buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + numBytes);
+ }
+
/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point