From 6e009cb9c4d7a395991e10dab427f37019283758 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Aug 2015 10:40:54 -0700 Subject: [SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info Author: Wenchen Fan Closes #7955 from cloud-fan/toSeq and squashes the following commits: 21665e2 [Wenchen Fan] fix hive again... 4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow --- .../apache/spark/sql/catalyst/InternalRow.scala | 132 ++------------------- .../sql/catalyst/expressions/Projection.scala | 12 +- .../catalyst/expressions/SpecificMutableRow.scala | 5 +- .../expressions/codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 ++++++++++++++++++++- .../catalyst/expressions/CodeGenerationSuite.scala | 2 +- 6 files changed, 154 insertions(+), 137 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7d17cca808..85b4bf3b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ @@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } - // Subclasses of InternalRow should implement all special getters and equals/hashCode, - // or implement this genericGet. 
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( - "Concrete internal rows should implement genericGet, " + - "or implement all special getters and equals/hashCode") - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123..59ce7fc4f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,7 +203,11 @@ class JoinedRow extends InternalRow { this } - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } override def numFields: Int = row1.numFields + row2.numFields @@ -276,11 +280,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.mkString("[", ",", "]") + row2.toString } else if (row2 eq null) { - row1.mkString("[", ",", "]") + row1.toString } else { - mkString("[", ",", "]") + s"{${row1.toString} + ${row2.toString}}" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b94df6bd66..4f56f94bd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c04fe734d5..c744e84d82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -25,6 +26,8 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} +abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow + /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { $columns @@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - protected Object genericGet(int i) { + @Override + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7657fb535d..207e667792 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,6 +21,130 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def toString(): String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -82,7 +206,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -90,7 +214,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -109,7 +233,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -117,7 +241,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow { override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e310aee221..e323467af5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { -- cgit v1.2.3
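
A minimal sketch of how the reworked toSeq is called once this patch is applied. GenericInternalRow, the Catalyst type objects and UTF8String come from the code touched above; the ToSeqSketch object, the column layout and the values are invented here purely for illustration:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object ToSeqSketch {
  def main(args: Array[String]): Unit = {
    // Internal rows hold Catalyst's internal representations, e.g. UTF8String for strings.
    val row: InternalRow = new GenericInternalRow(Array[Any](1, UTF8String.fromString("ab"), 2.5))

    // The new toSeq takes the field types so it can go through the typed
    // get(ordinal, dataType) accessor instead of a generic (boxed) getter.
    val fieldTypes = Seq(IntegerType, StringType, DoubleType)
    println(row.toSeq(fieldTypes).mkString("[", ",", "]"))   // [1,ab,2.5]

    // The StructType overload just extracts the data types from the schema.
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType),
      StructField("score", DoubleType)))
    println(row.toSeq(schema).mkString("[", ",", "]"))       // [1,ab,2.5]
  }
}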
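
Likewise, a hedged sketch of what the extracted BaseGenericInternalRow trait buys an implementor: supplying genericGet (plus numFields and copy) is enough to inherit the special getters, toString, equals and hashCode shown in the rows.scala hunk above. VectorBackedRow and BaseGenericInternalRowSketch are made-up names, not part of Spark:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BaseGenericInternalRow, GenericInternalRow}

// Hypothetical example: a row backed by an immutable Vector instead of an Array.
class VectorBackedRow(values: Vector[Any]) extends BaseGenericInternalRow {
  // The only data-access hook the trait needs; isNullAt, getInt, getUTF8String,
  // toString, equals, hashCode, ... are all derived from it.
  override protected def genericGet(ordinal: Int): Any = values(ordinal)
  override def numFields: Int = values.length
  override def copy(): InternalRow = new GenericInternalRow(values.toArray)
}

object BaseGenericInternalRowSketch {
  def main(args: Array[String]): Unit = {
    val a = new VectorBackedRow(Vector(1, 2L))
    val b = new GenericInternalRow(Array[Any](1, 2L))
    // equals/hashCode compare field by field via genericGet, so two different
    // BaseGenericInternalRow implementations holding the same values compare equal.
    println(a == b)   // true
    println(a)        // [1,2]
  }
}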