aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>2016-12-29 10:59:37 +0800
committerWenchen Fan <wenchen@databricks.com>2016-12-29 10:59:37 +0800
commit93f35569fd4e7dc1e4037d3df538a21c526f9c5d (patch)
tree0b2c59ed90ca9ae1ccb33a3956ac5a69dcdd344e /sql/catalyst/src
parent092c6725bf039bf33299b53791e1958c4ea3f6aa (diff)
downloadspark-93f35569fd4e7dc1e4037d3df538a21c526f9c5d.tar.gz
spark-93f35569fd4e7dc1e4037d3df538a21c526f9c5d.tar.bz2
spark-93f35569fd4e7dc1e4037d3df538a21c526f9c5d.zip
[SPARK-16213][SQL] Reduce runtime overhead of a program that creates an primitive array in DataFrame
## What changes were proposed in this pull request? This PR reduces runtime overhead of a program the creates an primitive array in DataFrame by using the similar approach to #15044. Generated code performs boxing operation in an assignment from InternalRow to an `Object[]` temporary array (at Lines 051 and 061 in the generated code before without this PR). If we know that type of array elements is primitive, we apply the following optimizations: 1. Eliminate a pair of `isNullAt()` and a null assignment 2. Allocate an primitive array instead of `Object[]` (eliminate boxing operations) 3. Create `UnsafeArrayData` by using `UnsafeArrayWriter` to keep a primitive array in a row format instead of doing non-lightweight operations in constructor of `GenericArrayData` The PR also performs the same things for `CreateMap`. Here are performance results of [DataFrame programs](https://github.com/kiszk/spark/blob/6bf54ec5e227689d69f6db991e9ecbc54e153d0a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala#L83-L112) by up to 17.9x over without this PR. ``` Without SPARK-16043 OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Read a primitive array in DataFrame: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 3805 / 4150 0.0 507308.9 1.0X Double 3593 / 3852 0.0 479056.9 1.1X With SPARK-16043 Read a primitive array in DataFrame: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 213 / 271 0.0 28387.5 1.0X Double 204 / 223 0.0 27250.9 1.0X ``` Note : #15780 is enabled for these measurements An motivating example ``` java val df = sparkContext.parallelize(Seq(0.0d, 1.0d), 1).toDF df.selectExpr("Array(value + 1.1d, value + 2.2d)").show ``` Generated code without this PR ``` java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ private Object[] project_values; /* 013 */ private UnsafeRow project_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter; /* 017 */ /* 018 */ public GeneratedIterator(Object[] references) { /* 019 */ this.references = references; /* 020 */ } /* 021 */ /* 022 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 023 */ partitionIndex = index; /* 024 */ this.inputs = inputs; /* 025 */ inputadapter_input = inputs[0]; /* 026 */ serializefromobject_result = new UnsafeRow(1); /* 027 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 028 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 029 */ this.project_values = null; /* 030 */ project_result = new UnsafeRow(1); /* 031 */ this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32); /* 032 */ this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1); /* 033 */ this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 034 */ /* 035 */ } /* 036 */ /* 037 */ protected void processNext() throws java.io.IOException { /* 038 */ while (inputadapter_input.hasNext()) { /* 039 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 040 */ double inputadapter_value = inputadapter_row.getDouble(0); /* 041 */ /* 042 */ final boolean project_isNull = false; /* 043 */ this.project_values = new Object[2]; /* 044 */ boolean project_isNull1 = false; /* 045 */ /* 046 */ double project_value1 = -1.0; /* 047 */ project_value1 = inputadapter_value + 1.1D; /* 048 */ if (false) { /* 049 */ project_values[0] = null; /* 050 */ } else { /* 051 */ project_values[0] = project_value1; /* 052 */ } /* 053 */ /* 054 */ boolean project_isNull4 = false; /* 055 */ /* 056 */ double project_value4 = -1.0; /* 057 */ project_value4 = inputadapter_value + 2.2D; /* 058 */ if (false) { /* 059 */ project_values[1] = null; /* 060 */ } else { /* 061 */ project_values[1] = project_value4; /* 062 */ } /* 063 */ /* 064 */ final ArrayData project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_values); /* 065 */ this.project_values = null; /* 066 */ project_holder.reset(); /* 067 */ /* 068 */ project_rowWriter.zeroOutNullBytes(); /* 069 */ /* 070 */ if (project_isNull) { /* 071 */ project_rowWriter.setNullAt(0); /* 072 */ } else { /* 073 */ // Remember the current cursor so that we can calculate how many bytes are /* 074 */ // written later. /* 075 */ final int project_tmpCursor = project_holder.cursor; /* 076 */ /* 077 */ if (project_value instanceof UnsafeArrayData) { /* 078 */ final int project_sizeInBytes = ((UnsafeArrayData) project_value).getSizeInBytes(); /* 079 */ // grow the global buffer before writing data. /* 080 */ project_holder.grow(project_sizeInBytes); /* 081 */ ((UnsafeArrayData) project_value).writeToMemory(project_holder.buffer, project_holder.cursor); /* 082 */ project_holder.cursor += project_sizeInBytes; /* 083 */ /* 084 */ } else { /* 085 */ final int project_numElements = project_value.numElements(); /* 086 */ project_arrayWriter.initialize(project_holder, project_numElements, 8); /* 087 */ /* 088 */ for (int project_index = 0; project_index < project_numElements; project_index++) { /* 089 */ if (project_value.isNullAt(project_index)) { /* 090 */ project_arrayWriter.setNullDouble(project_index); /* 091 */ } else { /* 092 */ final double project_element = project_value.getDouble(project_index); /* 093 */ project_arrayWriter.write(project_index, project_element); /* 094 */ } /* 095 */ } /* 096 */ } /* 097 */ /* 098 */ project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor); /* 099 */ } /* 100 */ project_result.setTotalSize(project_holder.totalSize()); /* 101 */ append(project_result); /* 102 */ if (shouldStop()) return; /* 103 */ } /* 104 */ } /* 105 */ } ``` Generated code with this PR ``` java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ private UnsafeArrayData project_arrayData; /* 013 */ private UnsafeRow project_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter; /* 017 */ /* 018 */ public GeneratedIterator(Object[] references) { /* 019 */ this.references = references; /* 020 */ } /* 021 */ /* 022 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 023 */ partitionIndex = index; /* 024 */ this.inputs = inputs; /* 025 */ inputadapter_input = inputs[0]; /* 026 */ serializefromobject_result = new UnsafeRow(1); /* 027 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 028 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 029 */ /* 030 */ project_result = new UnsafeRow(1); /* 031 */ this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32); /* 032 */ this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1); /* 033 */ this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 034 */ /* 035 */ } /* 036 */ /* 037 */ protected void processNext() throws java.io.IOException { /* 038 */ while (inputadapter_input.hasNext()) { /* 039 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 040 */ double inputadapter_value = inputadapter_row.getDouble(0); /* 041 */ /* 042 */ byte[] project_array = new byte[32]; /* 043 */ project_arrayData = new UnsafeArrayData(); /* 044 */ Platform.putLong(project_array, 16, 2); /* 045 */ project_arrayData.pointTo(project_array, 16, 32); /* 046 */ /* 047 */ boolean project_isNull1 = false; /* 048 */ /* 049 */ double project_value1 = -1.0; /* 050 */ project_value1 = inputadapter_value + 1.1D; /* 051 */ if (false) { /* 052 */ project_arrayData.setNullAt(0); /* 053 */ } else { /* 054 */ project_arrayData.setDouble(0, project_value1); /* 055 */ } /* 056 */ /* 057 */ boolean project_isNull4 = false; /* 058 */ /* 059 */ double project_value4 = -1.0; /* 060 */ project_value4 = inputadapter_value + 2.2D; /* 061 */ if (false) { /* 062 */ project_arrayData.setNullAt(1); /* 063 */ } else { /* 064 */ project_arrayData.setDouble(1, project_value4); /* 065 */ } /* 066 */ project_holder.reset(); /* 067 */ /* 068 */ // Remember the current cursor so that we can calculate how many bytes are /* 069 */ // written later. /* 070 */ final int project_tmpCursor = project_holder.cursor; /* 071 */ /* 072 */ if (project_arrayData instanceof UnsafeArrayData) { /* 073 */ final int project_sizeInBytes = ((UnsafeArrayData) project_arrayData).getSizeInBytes(); /* 074 */ // grow the global buffer before writing data. /* 075 */ project_holder.grow(project_sizeInBytes); /* 076 */ ((UnsafeArrayData) project_arrayData).writeToMemory(project_holder.buffer, project_holder.cursor); /* 077 */ project_holder.cursor += project_sizeInBytes; /* 078 */ /* 079 */ } else { /* 080 */ final int project_numElements = project_arrayData.numElements(); /* 081 */ project_arrayWriter.initialize(project_holder, project_numElements, 8); /* 082 */ /* 083 */ for (int project_index = 0; project_index < project_numElements; project_index++) { /* 084 */ if (project_arrayData.isNullAt(project_index)) { /* 085 */ project_arrayWriter.setNullDouble(project_index); /* 086 */ } else { /* 087 */ final double project_element = project_arrayData.getDouble(project_index); /* 088 */ project_arrayWriter.write(project_index, project_element); /* 089 */ } /* 090 */ } /* 091 */ } /* 092 */ /* 093 */ project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor); /* 094 */ project_result.setTotalSize(project_holder.totalSize()); /* 095 */ append(project_result); /* 096 */ if (shouldStop()) return; /* 097 */ } /* 098 */ } /* 099 */ } ``` ## How was this patch tested? Added unit tests into `DataFrameComplexTypeSuite` Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #13909 from kiszk/SPARK-16213.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java52
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala174
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala34
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala30
7 files changed, 222 insertions, 89 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index e8c33871f9..64ab01ca57 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -287,6 +287,58 @@ public final class UnsafeArrayData extends ArrayData {
return map;
}
+ @Override
+ public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+ public void setNullAt(int ordinal) {
+ assertIndexIsValid(ordinal);
+ BitSetMethods.set(baseObject, baseOffset + 8, ordinal);
+
+ /* we assume the corrresponding column was already 0 or
+ will be set to 0 later by the caller side */
+ }
+
+ public void setBoolean(int ordinal, boolean value) {
+ assertIndexIsValid(ordinal);
+ Platform.putBoolean(baseObject, getElementOffset(ordinal, 1), value);
+ }
+
+ public void setByte(int ordinal, byte value) {
+ assertIndexIsValid(ordinal);
+ Platform.putByte(baseObject, getElementOffset(ordinal, 1), value);
+ }
+
+ public void setShort(int ordinal, short value) {
+ assertIndexIsValid(ordinal);
+ Platform.putShort(baseObject, getElementOffset(ordinal, 2), value);
+ }
+
+ public void setInt(int ordinal, int value) {
+ assertIndexIsValid(ordinal);
+ Platform.putInt(baseObject, getElementOffset(ordinal, 4), value);
+ }
+
+ public void setLong(int ordinal, long value) {
+ assertIndexIsValid(ordinal);
+ Platform.putLong(baseObject, getElementOffset(ordinal, 8), value);
+ }
+
+ public void setFloat(int ordinal, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
+ assertIndexIsValid(ordinal);
+ Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value);
+ }
+
+ public void setDouble(int ordinal, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
+ assertIndexIsValid(ordinal);
+ Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value);
+ }
+
// This `hashCode` computation could consume much processor time for large data.
// If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes
// are used to compute `hashCode` (See `Vector.hashCode`).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 599fb638db..22277ad8d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -43,7 +44,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
- override def dataType: DataType = {
+ override def dataType: ArrayType = {
ArrayType(
children.headOption.map(_.dataType).getOrElse(NullType),
containsNull = children.exists(_.nullable))
@@ -56,33 +57,99 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val arrayClass = classOf[GenericArrayData].getName
- val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
- ev.copy(code = s"""
- this.$values = new Object[${children.size}];""" +
- ctx.splitExpressions(
- ctx.INPUT_ROW,
- children.zipWithIndex.map { case (e, i) =>
- val eval = e.genCode(ctx)
- eval.code + s"""
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }
- """
- }) +
- s"""
- final ArrayData ${ev.value} = new $arrayClass($values);
- this.$values = null;
- """, isNull = "false")
+ val et = dataType.elementType
+ val evals = children.map(e => e.genCode(ctx))
+ val (preprocess, assigns, postprocess, arrayData) =
+ GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
+ ev.copy(
+ code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
+ value = arrayData,
+ isNull = "false")
}
override def prettyName: String = "array"
}
+private [sql] object GenArrayData {
+ /**
+ * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class
+ *
+ * @param ctx a [[CodegenContext]]
+ * @param elementType data type of underlying array elements
+ * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+ * @param isMapKey if true, throw an exception when the element is null
+ * @return (code pre-assignments, assignments to each array elements, code post-assignments,
+ * arrayData name)
+ */
+ def genCodeToCreateArrayData(
+ ctx: CodegenContext,
+ elementType: DataType,
+ elementsCode: Seq[ExprCode],
+ isMapKey: Boolean): (String, Seq[String], String, String) = {
+ val arrayName = ctx.freshName("array")
+ val arrayDataName = ctx.freshName("arrayData")
+ val numElements = elementsCode.length
+
+ if (!ctx.isPrimitiveType(elementType)) {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ ctx.addMutableState("Object[]", arrayName,
+ s"this.$arrayName = new Object[${numElements}];")
+
+ val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+ val isNullAssignment = if (!isMapKey) {
+ s"$arrayName[$i] = null;"
+ } else {
+ "throw new RuntimeException(\"Cannot use null as map key!\");"
+ }
+ eval.code + s"""
+ if (${eval.isNull}) {
+ $isNullAssignment
+ } else {
+ $arrayName[$i] = ${eval.value};
+ }
+ """
+ }
+
+ ("",
+ assignments,
+ s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
+ arrayDataName)
+ } else {
+ val unsafeArraySizeInBytes =
+ UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
+
+ val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
+ val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+ val isNullAssignment = if (!isMapKey) {
+ s"$arrayDataName.setNullAt($i);"
+ } else {
+ "throw new RuntimeException(\"Cannot use null as map key!\");"
+ }
+ eval.code + s"""
+ if (${eval.isNull}) {
+ $isNullAssignment
+ } else {
+ $arrayDataName.set$primitiveValueTypeName($i, ${eval.value});
+ }
+ """
+ }
+
+ (s"""
+ byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
+ $arrayDataName = new UnsafeArrayData();
+ Platform.putLong($arrayName, $baseOffset, $numElements);
+ $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
+ """,
+ assignments,
+ "",
+ arrayDataName)
+ }
+ }
+}
+
/**
* Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
* The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
@@ -133,49 +200,26 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val arrayClass = classOf[GenericArrayData].getName
val mapClass = classOf[ArrayBasedMapData].getName
- val keyArray = ctx.freshName("keyArray")
- val valueArray = ctx.freshName("valueArray")
- ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
- ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
-
- val keyData = s"new $arrayClass($keyArray)"
- val valueData = s"new $arrayClass($valueArray)"
- ev.copy(code = s"""
- $keyArray = new Object[${keys.size}];
- $valueArray = new Object[${values.size}];""" +
- ctx.splitExpressions(
- ctx.INPUT_ROW,
- keys.zipWithIndex.map { case (key, i) =>
- val eval = key.genCode(ctx)
- s"""
- ${eval.code}
- if (${eval.isNull}) {
- throw new RuntimeException("Cannot use null as map key!");
- } else {
- $keyArray[$i] = ${eval.value};
- }
- """
- }) +
- ctx.splitExpressions(
- ctx.INPUT_ROW,
- values.zipWithIndex.map { case (value, i) =>
- val eval = value.genCode(ctx)
- s"""
- ${eval.code}
- if (${eval.isNull}) {
- $valueArray[$i] = null;
- } else {
- $valueArray[$i] = ${eval.value};
- }
- """
- }) +
+ val MapType(keyDt, valueDt, _) = dataType
+ val evalKeys = keys.map(e => e.genCode(ctx))
+ val evalValues = values.map(e => e.genCode(ctx))
+ val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) =
+ GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true)
+ val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
+ GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
+ val code =
s"""
- final MapData ${ev.value} = new $mapClass($keyData, $valueData);
- this.$keyArray = null;
- this.$valueArray = null;
- """, isNull = "false")
+ final boolean ${ev.isNull} = false;
+ $preprocessKeyData
+ ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
+ $postprocessKeyData
+ $preprocessValueData
+ ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
+ $postprocessValueData
+ final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
+ """
+ ev.copy(code = code)
}
override def prettyName: String = "map"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 140e86d670..9beef41d63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -42,6 +42,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
def array: Array[Any]
+ def setNullAt(i: Int): Unit
+
+ def update(i: Int, value: Any): Unit
+
+ // default implementation (slow)
+ def setBoolean(i: Int, value: Boolean): Unit = update(i, value)
+ def setByte(i: Int, value: Byte): Unit = update(i, value)
+ def setShort(i: Int, value: Short): Unit = update(i, value)
+ def setInt(i: Int, value: Int): Unit = update(i, value)
+ def setLong(i: Int, value: Long): Unit = update(i, value)
+ def setFloat(i: Int, value: Float): Unit = update(i, value)
+ def setDouble(i: Int, value: Double): Unit = update(i, value)
+
def toBooleanArray(): Array[Boolean] = {
val size = numElements()
val values = new Array[Boolean](size)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 7ee9581b63..dd660c80a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -71,6 +71,10 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
override def getMap(ordinal: Int): MapData = getAs(ordinal)
+ override def setNullAt(ordinal: Int): Unit = array(ordinal) = null
+
+ override def update(ordinal: Int, value: Any): Unit = array(ordinal) = value
+
override def toString(): String = array.mkString("[", ",", "]")
override def equals(o: Any): Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index ee5d1f6373..587022f0a2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
@@ -71,7 +71,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq.fill(length)(true)
- if (!checkResult(actual, expected)) {
+ if (actual != expected) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -106,9 +106,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
val plan = GenerateMutableProjection.generate(expressions)
val actual = plan(null).toSeq(expressions.map(_.dataType))
- val expected = Seq(UTF8String.fromString("abc"))
+ assert(actual.length == 1)
+ val expected = UTF8String.fromString("abc")
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual.head, expected, expressions.head.dataType)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -118,9 +119,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
val plan = GenerateMutableProjection.generate(expressions)
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
- val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))
+ assert(actual.length == 1)
+ val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual.head, expected, expressions.head.dataType)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -132,12 +134,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
case (expr, i) => Seq(Literal(i), expr)
}))
val plan = GenerateMutableProjection.generate(expressions)
- val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
- case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
- }
- val expected = (0 until length).map((_, true)).toMap :: Nil
+ val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
+ assert(actual.length == 1)
+ val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual.head, expected, expressions.head.dataType)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -149,7 +150,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual, expected, expressions.head.dataType)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -162,9 +163,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}))
val plan = GenerateMutableProjection.generate(expressions)
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
- val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
+ assert(actual.length == 1)
+ val expected = InternalRow(Seq.fill(length)(true): _*)
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual.head, expected, expressions.head.dataType)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -177,7 +179,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq(Row.fromSeq(Seq.fill(length)(1)))
- if (!checkResult(actual, expected)) {
+ if (actual != expected) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
@@ -194,7 +196,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val expected = Seq.fill(length)(
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00")))
- if (!checkResult(actual, expected)) {
+ if (actual != expected) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index c21c6de32c..abe1d2b2c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -120,16 +120,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("CreateArray") {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
+ val byteSeq = intSeq.map(_.toByte)
val strSeq = intSeq.map(_.toString)
checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
+ checkEvaluation(CreateArray(byteSeq.map(Literal(_))), byteSeq, EmptyRow)
checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
+ val byteWithNull = byteSeq.map(Literal(_)) :+ Literal.create(null, ByteType)
val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
+ checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f83650424a..1ba6dd1c5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.MapData
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
@@ -59,14 +59,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
* Check the equality between result of expression and expected value, it will handle
* Array[Byte], Spread[Double], and MapData.
*/
- protected def checkResult(result: Any, expected: Any): Boolean = {
+ protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
case (result: Double, expected: Spread[Double @unchecked]) =>
expected.asInstanceOf[Spread[Double]].isWithin(result)
+ case (result: ArrayData, expected: ArrayData) =>
+ result.numElements == expected.numElements && {
+ val et = dataType.asInstanceOf[ArrayType].elementType
+ var isSame = true
+ var i = 0
+ while (isSame && i < result.numElements) {
+ isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+ i += 1
+ }
+ isSame
+ }
case (result: MapData, expected: MapData) =>
- result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray()
+ val kt = dataType.asInstanceOf[MapType].keyType
+ val vt = dataType.asInstanceOf[MapType].valueType
+ checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
+ checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
case (result: Double, expected: Double) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
@@ -108,7 +122,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual, expected, expression.dataType)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
@@ -127,7 +141,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
plan.initialize(0)
val actual = plan(inputRow).get(0, expression.dataType)
- if (!checkResult(actual, expected)) {
+ if (!checkResult(actual, expected, expression.dataType)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
}
@@ -188,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expression)
plan.initialize(0)
var actual = plan(inputRow).get(0, expression.dataType)
- assert(checkResult(actual, expected))
+ assert(checkResult(actual, expected, expression.dataType))
plan = generateProject(
GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
@@ -196,7 +210,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
plan.initialize(0)
actual = FromUnsafeProjection(expression.dataType :: Nil)(
plan(inputRow)).get(0, expression.dataType)
- assert(checkResult(actual, expected))
+ assert(checkResult(actual, expected, expression.dataType))
}
/**