author     Kazuaki Ishizaki <ishizaki@jp.ibm.com>            2016-11-08 00:14:57 +0100
committer  Herman van Hovell <hvanhovell@databricks.com>     2016-11-08 00:14:57 +0100
commit     19cf208063f035d793d2306295a251a9af7e32f6 (patch)
tree       baf36494b9b5410e9a94782afdfa6972430b6316 /sql/catalyst
parent     8f0ea011a7294679ec4275b2fef349ef45b6eb81 (diff)
[SPARK-17490][SQL] Optimize SerializeFromObject() for a primitive array
## What changes were proposed in this pull request?

Waiting for #13680 to be merged.

This PR optimizes `SerializeFromObject()` for a primitive array. It is derived from #13758 and addresses one of the problems there in a simpler way.

The current implementation always generates `GenericArrayData` from `SerializeFromObject()` for any array type in a logical plan. This involves boxing in the `GenericArrayData` constructor when `SerializeFromObject()` receives a primitive array. This PR instead generates `UnsafeArrayData` from `SerializeFromObject()` for a primitive array, which avoids boxing when creating the `ArrayData` instance in the code generated by Catalyst. This PR also generates `UnsafeArrayData` in the `RowEncoder.serializeFor` and `CatalystTypeConverters.createToCatalystConverter` paths.

The performance improvement of `SerializeFromObject()` is up to 2.0x:

```
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            556 /  608         15.1          66.3       1.0X
Double                                        1668 / 1746          5.0         198.8       0.3X

With this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            352 /  401         23.8          42.0       1.0X
Double                                         821 /  885         10.2          97.9       0.4X
```

Here is an example program of the pattern that occurs in MLlib, as described in [SPARK-16070](https://issues.apache.org/jira/browse/SPARK-16070):

```
sparkContext.parallelize(Seq(Array(1, 2)), 1).toDS.map(e => e).show
```

Generated code before applying this PR:

```java
/* 039 */   protected void processNext() throws java.io.IOException {
/* 040 */     while (inputadapter_input.hasNext()) {
/* 041 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 042 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 043 */
/* 044 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 045 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 046 */
/* 047 */       boolean mapelements_isNull = false || false;
/* 048 */       int[] mapelements_value = null;
/* 049 */       if (!mapelements_isNull) {
/* 050 */         Object mapelements_funcResult = null;
/* 051 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 052 */         if (mapelements_funcResult == null) {
/* 053 */           mapelements_isNull = true;
/* 054 */         } else {
/* 055 */           mapelements_value = (int[]) mapelements_funcResult;
/* 056 */         }
/* 057 */
/* 058 */       }
/* 059 */       mapelements_isNull = mapelements_value == null;
/* 060 */
/* 061 */       serializefromobject_argIsNulls[0] = mapelements_isNull;
/* 062 */       serializefromobject_argValue = mapelements_value;
/* 063 */
/* 064 */       boolean serializefromobject_isNull = false;
/* 065 */       for (int idx = 0; idx < 1; idx++) {
/* 066 */         if (serializefromobject_argIsNulls[idx]) { serializefromobject_isNull = true; break; }
/* 067 */       }
/* 068 */
/* 069 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 070 */       serializefromobject_holder.reset();
/* 071 */
/* 072 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 073 */
/* 074 */       if (serializefromobject_isNull) {
/* 075 */         serializefromobject_rowWriter.setNullAt(0);
/* 076 */       } else {
/* 077 */         // Remember the current cursor so that we can calculate how many bytes are
/* 078 */         // written later.
/* 079 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 080 */
/* 081 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 082 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 083 */           // grow the global buffer before writing data.
/* 084 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 085 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 086 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 087 */
/* 088 */         } else {
/* 089 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 090 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 091 */
/* 092 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 093 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 094 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 095 */             } else {
/* 096 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 097 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 098 */             }
/* 099 */           }
/* 100 */         }
/* 101 */
/* 102 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 103 */       }
/* 104 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 105 */       append(serializefromobject_result);
/* 106 */       if (shouldStop()) return;
/* 107 */     }
/* 108 */   }
/* 109 */ }
```

Generated code after applying this PR:

```java
/* 035 */   protected void processNext() throws java.io.IOException {
/* 036 */     while (inputadapter_input.hasNext()) {
/* 037 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 038 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 039 */
/* 040 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 041 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 042 */
/* 043 */       boolean mapelements_isNull = false || false;
/* 044 */       int[] mapelements_value = null;
/* 045 */       if (!mapelements_isNull) {
/* 046 */         Object mapelements_funcResult = null;
/* 047 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 048 */         if (mapelements_funcResult == null) {
/* 049 */           mapelements_isNull = true;
/* 050 */         } else {
/* 051 */           mapelements_value = (int[]) mapelements_funcResult;
/* 052 */         }
/* 053 */
/* 054 */       }
/* 055 */       mapelements_isNull = mapelements_value == null;
/* 056 */
/* 057 */       boolean serializefromobject_isNull = mapelements_isNull;
/* 058 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(mapelements_value);
/* 059 */       serializefromobject_isNull = serializefromobject_value == null;
/* 060 */       serializefromobject_holder.reset();
/* 061 */
/* 062 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 063 */
/* 064 */       if (serializefromobject_isNull) {
/* 065 */         serializefromobject_rowWriter.setNullAt(0);
/* 066 */       } else {
/* 067 */         // Remember the current cursor so that we can calculate how many bytes are
/* 068 */         // written later.
/* 069 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 070 */
/* 071 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 072 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 073 */           // grow the global buffer before writing data.
/* 074 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 075 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 076 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 077 */
/* 078 */         } else {
/* 079 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 080 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 081 */
/* 082 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 083 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 084 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 085 */             } else {
/* 086 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 087 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 088 */             }
/* 089 */           }
/* 090 */         }
/* 091 */
/* 092 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 093 */       }
/* 094 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 095 */       append(serializefromobject_result);
/* 096 */       if (shouldStop()) return;
/* 097 */     }
/* 098 */   }
/* 099 */ }
```

## How was this patch tested?

Added tests in `DatasetSuite`, `RowEncoderSuite`, and `CatalystTypeConvertersSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15044 from kiszk/SPARK-17490.
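For illustration only (not part of this commit), here is a minimal Scala sketch of the two conversion paths described above, assuming a Spark 2.1-era `sql/catalyst` on the classpath (e.g. pasted into spark-shell):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

val values = Array(1, 2, 3)

// Old path: GenericArrayData stores elements as objects, so every Int here
// ends up boxed into a java.lang.Integer.
val boxed = new GenericArrayData(values)

// New path for primitive arrays: a flat, byte-array-backed copy with no
// per-element boxing.
val unsafe = UnsafeArrayData.fromPrimitiveArray(values)

// Both expose the same ArrayData view of the data.
assert(boxed.numElements() == unsafe.numElements())
assert(boxed.getInt(0) == unsafe.getInt(0))
```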
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala              | 16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala          | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala               | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala  | 33
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala     | 26
5 files changed, 103 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 31c6e5def1..7bcaea7ea2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -441,6 +441,22 @@ object ScalaReflection extends ScalaReflection {
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, dt)
+ case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType) =>
+ val cls = input.dataType.asInstanceOf[ObjectType].cls
+ if (cls.isArray && cls.getComponentType.isPrimitive) {
+ StaticInvoke(
+ classOf[UnsafeArrayData],
+ ArrayType(dt, false),
+ "fromPrimitiveArray",
+ input :: Nil)
+ } else {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(dt, schemaFor(elementType).nullable))
+ }
+
case dt =>
NewInstance(
classOf[GenericArrayData],
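The new branch above picks a serializer from the runtime class of the external object. A small sketch of that reflection check (plain Scala, not part of the patch):

```scala
// A JVM int[] reports an array type with a primitive component, so the branch
// above routes it to UnsafeArrayData.fromPrimitiveArray.
val intArrayCls = classOf[Array[Int]]
assert(intArrayCls.isArray && intArrayCls.getComponentType.isPrimitive)

// An array of boxed Integers does not, so it still falls back to the
// GenericArrayData constructor.
val boxedArrayCls = classOf[Array[java.lang.Integer]]
assert(boxedArrayCls.isArray && !boxedArrayCls.getComponentType.isPrimitive)
```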
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 2a6fcd03a2..e95e97b9dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -119,18 +119,19 @@ object RowEncoder {
"fromString",
inputObject :: Nil)
- case t @ ArrayType(et, _) => et match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
- // TODO: validate input type for primitive array.
- NewInstance(
- classOf[GenericArrayData],
- inputObject :: Nil,
- dataType = t)
- case _ => MapObjects(
- element => serializerFor(ValidateExternalType(element, et), et),
- inputObject,
- ObjectType(classOf[Object]))
- }
+ case t @ ArrayType(et, cn) =>
+ et match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ StaticInvoke(
+ classOf[ArrayData],
+ t,
+ "toArrayData",
+ inputObject :: Nil)
+ case _ => MapObjects(
+ element => serializerFor(ValidateExternalType(element, et), et),
+ inputObject,
+ ObjectType(classOf[Object]))
+ }
case t @ MapType(kt, vt, valueNullable) =>
val keys =
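For reference, a hedged sketch of the serializer expression the new `RowEncoder` branch builds for a non-nullable integer array column; `inputObject` here is only a stand-in for the external value, and evaluation happens through the code Catalyst generates from the expression:

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType, ObjectType}

// Stand-in for the external Array[Int] flowing into the serializer.
val inputObject = BoundReference(0, ObjectType(classOf[Array[Int]]), nullable = true)

// Equivalent to a static call ArrayData.toArrayData(<input>) in the generated code.
val serializer = StaticInvoke(
  classOf[ArrayData],
  ArrayType(IntegerType, containsNull = false),
  "toArrayData",
  inputObject :: Nil)
```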
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index cad4a08b0d..140e86d670 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
import org.apache.spark.sql.types.DataType
+object ArrayData {
+ def toArrayData(input: Any): ArrayData = input match {
+ case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
+ case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
+ case other => new GenericArrayData(other)
+ }
+}
+
abstract class ArrayData extends SpecializedGetters with Serializable {
def numElements(): Int
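A minimal usage sketch of the new `ArrayData.toArrayData` helper (not part of the patch): primitive arrays take the `UnsafeArrayData` fast path, while other inputs keep falling back to `GenericArrayData`.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}

// An int[] is copied into an UnsafeArrayData without boxing.
val fast = ArrayData.toArrayData(Array(1, 2, 3))
assert(fast.isInstanceOf[UnsafeArrayData])

// A non-primitive input (here a Seq) still becomes a GenericArrayData.
val slow = ArrayData.toArrayData(Seq(1, 2, 3))
assert(slow.isInstanceOf[GenericArrayData])
```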
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 03bb102c67..f3702ec92b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
class CatalystTypeConvertersSuite extends SparkFunSuite {
@@ -61,4 +63,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
test("option handling in createToCatalystConverter") {
assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
}
+
+ test("primitive array handling") {
+ val intArray = Array(1, 100, 10000)
+ val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
+ val intArrayType = ArrayType(IntegerType, false)
+ assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)
+
+ val doubleArray = Array(1.1, 111.1, 11111.1)
+ val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
+ val doubleArrayType = ArrayType(DoubleType, false)
+ assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
+ === doubleArray)
+ }
+
+ test("An array with null handling") {
+ val intArray = Array(1, null, 100, null, 10000)
+ val intGenericArray = new GenericArrayData(intArray)
+ val intArrayType = ArrayType(IntegerType, true)
+ assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
+ === intArray)
+ assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
+ == intGenericArray)
+
+ val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
+ val doubleGenericArray = new GenericArrayData(doubleArray)
+ val doubleArrayType = ArrayType(DoubleType, true)
+ assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
+ === doubleArray)
+ assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
+ == doubleGenericArray)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 2e513ea22c..1a5569a77d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}
+ test("RowEncoder should support primitive arrays") {
+ val schema = new StructType()
+ .add("booleanPrimitiveArray", ArrayType(BooleanType, false))
+ .add("bytePrimitiveArray", ArrayType(ByteType, false))
+ .add("shortPrimitiveArray", ArrayType(ShortType, false))
+ .add("intPrimitiveArray", ArrayType(IntegerType, false))
+ .add("longPrimitiveArray", ArrayType(LongType, false))
+ .add("floatPrimitiveArray", ArrayType(FloatType, false))
+ .add("doublePrimitiveArray", ArrayType(DoubleType, false))
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val input = Seq(
+ Array(true, false),
+ Array(1.toByte, 64.toByte, Byte.MaxValue),
+ Array(1.toShort, 255.toShort, Short.MaxValue),
+ Array(1, 10000, Int.MaxValue),
+ Array(1.toLong, 1000000.toLong, Long.MaxValue),
+ Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
+ Array(11.1111, 123456.7890123, Double.MaxValue)
+ )
+ val row = encoder.toRow(Row.fromSeq(input))
+ val convertedBack = encoder.fromRow(row)
+ input.zipWithIndex.map { case (array, index) =>
+ assert(convertedBack.getSeq(index) === array)
+ }
+ }
+
test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType()
.add("array", ArrayType(IntegerType))