diff options
9 files changed, 230 insertions, 91 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index e8c33871f9..64ab01ca57 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -287,6 +287,58 @@ public final class UnsafeArrayData extends ArrayData { return map; } + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + public void setNullAt(int ordinal) { + assertIndexIsValid(ordinal); + BitSetMethods.set(baseObject, baseOffset + 8, ordinal); + + /* we assume the corresponding column was already 0 or + will be set to 0 later by the caller side */ + } + + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + Platform.putBoolean(baseObject, getElementOffset(ordinal, 1), value); + } + + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + Platform.putByte(baseObject, getElementOffset(ordinal, 1), value); + } + + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + Platform.putShort(baseObject, getElementOffset(ordinal, 2), value); + } + + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + Platform.putInt(baseObject, getElementOffset(ordinal, 4), value); + } + + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + Platform.putLong(baseObject, getElementOffset(ordinal, 8), value); + } + + public void setFloat(int ordinal, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + assertIndexIsValid(ordinal); + Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value); + } + + public void setDouble(int ordinal, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + assertIndexIsValid(ordinal); + 
Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value); + } + // This `hashCode` computation could consume much processor time for large data. // If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes // are used to compute `hashCode` (See `Vector.hashCode`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 599fb638db..22277ad8d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String /** @@ -43,7 +44,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") - override def dataType: DataType = { + override def dataType: ArrayType = { ArrayType( children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) @@ -56,33 +57,99 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayClass = 
classOf[GenericArrayData].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ - }) + - s""" - final ArrayData ${ev.value} = new $arrayClass($values); - this.$values = null; - """, isNull = "false") + val et = dataType.elementType + val evals = children.map(e => e.genCode(ctx)) + val (preprocess, assigns, postprocess, arrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) + ev.copy( + code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, + value = arrayData, + isNull = "false") } override def prettyName: String = "array" } +private [sql] object GenArrayData { + /** + * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class + * + * @param ctx a [[CodegenContext]] + * @param elementType data type of underlying array elements + * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @param isMapKey if true, throw an exception when the element is null + * @return (code pre-assignments, assignments to each array elements, code post-assignments, + * arrayData name) + */ + def genCodeToCreateArrayData( + ctx: CodegenContext, + elementType: DataType, + elementsCode: Seq[ExprCode], + isMapKey: Boolean): (String, Seq[String], String, String) = { + val arrayName = ctx.freshName("array") + val arrayDataName = ctx.freshName("arrayData") + val numElements = elementsCode.length + + if (!ctx.isPrimitiveType(elementType)) { + val genericArrayClass = classOf[GenericArrayData].getName + ctx.addMutableState("Object[]", arrayName, + s"this.$arrayName = new Object[${numElements}];") + + val assignments = 
elementsCode.zipWithIndex.map { case (eval, i) => + val isNullAssignment = if (!isMapKey) { + s"$arrayName[$i] = null;" + } else { + "throw new RuntimeException(\"Cannot use null as map key!\");" + } + eval.code + s""" + if (${eval.isNull}) { + $isNullAssignment + } else { + $arrayName[$i] = ${eval.value}; + } + """ + } + + ("", + assignments, + s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", + arrayDataName) + } else { + val unsafeArraySizeInBytes = + UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) + val baseOffset = Platform.BYTE_ARRAY_OFFSET + ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + + val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val assignments = elementsCode.zipWithIndex.map { case (eval, i) => + val isNullAssignment = if (!isMapKey) { + s"$arrayDataName.setNullAt($i);" + } else { + "throw new RuntimeException(\"Cannot use null as map key!\");" + } + eval.code + s""" + if (${eval.isNull}) { + $isNullAssignment + } else { + $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); + } + """ + } + + (s""" + byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; + $arrayDataName = new UnsafeArrayData(); + Platform.putLong($arrayName, $baseOffset, $numElements); + $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); + """, + assignments, + "", + arrayDataName) + } + } +} + /** * Returns a catalyst Map containing the evaluation of all children expressions as keys and values. * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...) 
@@ -133,49 +200,26 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayClass = classOf[GenericArrayData].getName val mapClass = classOf[ArrayBasedMapData].getName - val keyArray = ctx.freshName("keyArray") - val valueArray = ctx.freshName("valueArray") - ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;") - ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;") - - val keyData = s"new $arrayClass($keyArray)" - val valueData = s"new $arrayClass($valueArray)" - ev.copy(code = s""" - $keyArray = new Object[${keys.size}]; - $valueArray = new Object[${values.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - keys.zipWithIndex.map { case (key, i) => - val eval = key.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - throw new RuntimeException("Cannot use null as map key!"); - } else { - $keyArray[$i] = ${eval.value}; - } - """ - }) + - ctx.splitExpressions( - ctx.INPUT_ROW, - values.zipWithIndex.map { case (value, i) => - val eval = value.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $valueArray[$i] = null; - } else { - $valueArray[$i] = ${eval.value}; - } - """ - }) + + val MapType(keyDt, valueDt, _) = dataType + val evalKeys = keys.map(e => e.genCode(ctx)) + val evalValues = values.map(e => e.genCode(ctx)) + val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true) + val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) + val code = s""" - final MapData ${ev.value} = new $mapClass($keyData, $valueData); - this.$keyArray = null; - this.$valueArray = null; - """, isNull = "false") + final boolean ${ev.isNull} = false; + $preprocessKeyData + ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)} + $postprocessKeyData + 
$preprocessValueData + ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)} + $postprocessValueData + final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); + """ + ev.copy(code = code) } override def prettyName: String = "map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 140e86d670..9beef41d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -42,6 +42,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def setNullAt(i: Int): Unit + + def update(i: Int, value: Any): Unit + + // default implementation (slow) + def setBoolean(i: Int, value: Boolean): Unit = update(i, value) + def setByte(i: Int, value: Byte): Unit = update(i, value) + def setShort(i: Int, value: Short): Unit = update(i, value) + def setInt(i: Int, value: Int): Unit = update(i, value) + def setLong(i: Int, value: Long): Unit = update(i, value) + def setFloat(i: Int, value: Float): Unit = update(i, value) + def setDouble(i: Int, value: Double): Unit = update(i, value) + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 7ee9581b63..dd660c80a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -71,6 +71,10 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getMap(ordinal: Int): MapData = 
getAs(ordinal) + override def setNullAt(ordinal: Int): Unit = array(ordinal) = null + + override def update(ordinal: Int, value: Any): Unit = array(ordinal) = value + override def toString(): String = array.mkString("[", ",", "]") override def equals(o: Any): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index ee5d1f6373..587022f0a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ThreadUtils @@ -71,7 +71,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -106,9 +106,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(null).toSeq(expressions.map(_.dataType)) 
- val expected = Seq(UTF8String.fromString("abc")) + assert(actual.length == 1) + val expected = UTF8String.fromString("abc") - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -118,9 +119,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(new GenericArrayData(Seq.fill(length)(true))) + assert(actual.length == 1) + val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true)) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -132,12 +134,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { case (expr, i) => Seq(Literal(i), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map { - case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m) - } - val expected = (0 until length).map((_, true)).toMap :: Nil + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) + assert(actual.length == 1) + val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true)) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -149,7 +150,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = 
plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) - if (!checkResult(actual, expected)) { + if (!checkResult(actual, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -162,9 +163,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { })) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) + assert(actual.length == 1) + val expected = InternalRow(Seq.fill(length)(true): _*) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -177,7 +179,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(Row.fromSeq(Seq.fill(length)(1))) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -194,7 +196,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expected = Seq.fill(length)( DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index c21c6de32c..abe1d2b2c9 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -120,16 +120,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateArray") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) + val byteSeq = intSeq.map(_.toByte) val strSeq = intSeq.map(_.toString) checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(byteSeq.map(Literal(_))), byteSeq, EmptyRow) checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val byteWithNull = byteSeq.map(Literal(_)) :+ Literal.create(null, ByteType) val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f83650424a..1ba6dd1c5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import 
org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.MapData -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** @@ -59,14 +59,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], and MapData. */ - protected def checkResult(result: Any, expected: Any): Boolean = { + protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: ArrayData, expected: ArrayData) => + result.numElements == expected.numElements && { + val et = dataType.asInstanceOf[ArrayType].elementType + var isSame = true + var i = 0 + while (isSame && i < result.numElements) { + isSame = checkResult(result.get(i, et), expected.get(i, et), et) + i += 1 + } + isSame + } case (result: MapData, expected: MapData) => - result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray() + val kt = dataType.asInstanceOf[MapType].keyType + val vt = dataType.asInstanceOf[MapType].valueType + checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) && + checkResult(result.valueArray, expected.valueArray, ArrayType(vt)) case (result: Double, expected: Double) => if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => @@ -108,7 +122,7 @@ trait ExpressionEvalHelper extends 
GeneratorDrivenPropertyChecks { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (!checkResult(actual, expected)) { + if (!checkResult(actual, expected, expression.dataType)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -127,7 +141,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan.initialize(0) val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected)) { + if (!checkResult(actual, expected, expression.dataType)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } @@ -188,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression) plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) - assert(checkResult(actual, expected)) + assert(checkResult(actual, expected, expression.dataType)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), @@ -196,7 +210,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan.initialize(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) - assert(checkResult(actual, expected)) + assert(checkResult(actual, expected, expression.dataType)) } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index ff07940422..354c878aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -246,6 +246,12 @@ public abstract class ColumnVector implements 
AutoCloseable { public Object get(int ordinal, DataType dataType) { throw new UnsupportedOperationException(); } + + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + @Override + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 9a8d4498bb..9eaf44c043 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -411,8 +411,8 @@ class ObjectHashAggregateSuite actual.zip(expected).foreach { case (lhs: Row, rhs: Row) => assert(lhs.length == rhs.length) lhs.toSeq.zip(rhs.toSeq).foreach { - case (a: Double, b: Double) => checkResult(a, b +- tolerance) - case (a, b) => checkResult(a, b) + case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType) + case (a, b) => a == b } } } |