aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-06 13:11:59 -0700
committerReynold Xin <rxin@databricks.com>2015-08-06 13:11:59 -0700
commit1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21 (patch)
treef04d74dffd581fa1eeb8e7a1f929f2aa843cf0a0 /sql/catalyst
parenta1bbf1bc5c51cd796015ac159799cf024de6fa07 (diff)
downloadspark-1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21.tar.gz
spark-1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21.tar.bz2
spark-1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21.zip
[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info
This re-applies #7955, which was reverted due to a race condition to fix build breaking. Author: Wenchen Fan <cloud0fan@outlook.com> Author: Reynold Xin <rxin@databricks.com> Closes #8002 from rxin/InternalRow-toSeq and squashes the following commits: 332416a [Reynold Xin] Merge pull request #7955 from cloud-fan/toSeq 21665e2 [Wenchen Fan] fix hive again... 4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala132
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala132
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala2
6 files changed, 154 insertions, 137 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 7d17cca808..85b4bf3b6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
// This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString
- override def toString: String = mkString("[", ",", "]")
-
/**
* Make a copy of the current [[InternalRow]] object.
*/
@@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
false
}
- // Subclasses of InternalRow should implement all special getters and equals/hashCode,
- // or implement this genericGet.
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
- "Concrete internal rows should implement genericGet, " +
- "or implement all special getters and equals/hashCode")
-
- // default implementation (slow)
- private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
- override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
- override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
- override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
- override def getByte(ordinal: Int): Byte = getAs(ordinal)
- override def getShort(ordinal: Int): Short = getAs(ordinal)
- override def getInt(ordinal: Int): Int = getAs(ordinal)
- override def getLong(ordinal: Int): Long = getAs(ordinal)
- override def getFloat(ordinal: Int): Float = getAs(ordinal)
- override def getDouble(ordinal: Int): Double = getAs(ordinal)
- override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
- override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
- override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
- override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
- override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
- override def getMap(ordinal: Int): MapData = getAs(ordinal)
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
- override def equals(o: Any): Boolean = {
- if (!o.isInstanceOf[InternalRow]) {
- return false
- }
-
- val other = o.asInstanceOf[InternalRow]
- if (other eq null) {
- return false
- }
-
- val len = numFields
- if (len != other.numFields) {
- return false
- }
-
- var i = 0
- while (i < len) {
- if (isNullAt(i) != other.isNullAt(i)) {
- return false
- }
- if (!isNullAt(i)) {
- val o1 = genericGet(i)
- val o2 = other.genericGet(i)
- o1 match {
- case b1: Array[Byte] =>
- if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
- return false
- }
- case f1: Float if java.lang.Float.isNaN(f1) =>
- if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
- return false
- }
- case d1: Double if java.lang.Double.isNaN(d1) =>
- if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
- return false
- }
- case _ => if (o1 != o2) {
- return false
- }
- }
- }
- i += 1
- }
- true
- }
-
- // Custom hashCode function that matches the efficient code generated version.
- override def hashCode: Int = {
- var result: Int = 37
- var i = 0
- val len = numFields
- while (i < len) {
- val update: Int =
- if (isNullAt(i)) {
- 0
- } else {
- genericGet(i) match {
- case b: Boolean => if (b) 0 else 1
- case b: Byte => b.toInt
- case s: Short => s.toInt
- case i: Int => i
- case l: Long => (l ^ (l >>> 32)).toInt
- case f: Float => java.lang.Float.floatToIntBits(f)
- case d: Double =>
- val b = java.lang.Double.doubleToLongBits(d)
- (b ^ (b >>> 32)).toInt
- case a: Array[Byte] => java.util.Arrays.hashCode(a)
- case other => other.hashCode()
- }
- }
- result = 37 * result + update
- i += 1
- }
- result
- }
-
/* ---------------------- utility methods for Scala ---------------------- */
/**
* Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
*/
- // todo: remove this as it needs the generic getter
- def toSeq: Seq[Any] = {
- val n = numFields
- val values = new Array[Any](n)
+ def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
+ val len = numFields
+ assert(len == fieldTypes.length)
+
+ val values = new Array[Any](len)
var i = 0
- while (i < n) {
- values.update(i, genericGet(i))
+ while (i < len) {
+ values(i) = get(i, fieldTypes(i))
i += 1
}
values
}
- /** Displays all elements of this sequence in a string (without a separator). */
- def mkString: String = toSeq.mkString
-
- /** Displays all elements of this sequence in a string using a separator string. */
- def mkString(sep: String): String = toSeq.mkString(sep)
-
- /**
- * Displays all elements of this traversable or iterator in a string using
- * start, end, and separator strings.
- */
- def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
+ def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType))
}
object InternalRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 4296b4b123..59ce7fc4f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -203,7 +203,11 @@ class JoinedRow extends InternalRow {
this
}
- override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
+ override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
+ assert(fieldTypes.length == row1.numFields + row2.numFields)
+ val (left, right) = fieldTypes.splitAt(row1.numFields)
+ row1.toSeq(left) ++ row2.toSeq(right)
+ }
override def numFields: Int = row1.numFields + row2.numFields
@@ -276,11 +280,11 @@ class JoinedRow extends InternalRow {
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
} else if (row1 eq null) {
- row2.mkString("[", ",", "]")
+ row2.toString
} else if (row2 eq null) {
- row1.mkString("[", ",", "]")
+ row1.toString
} else {
- mkString("[", ",", "]")
+ s"{${row1.toString} + ${row2.toString}}"
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index b94df6bd66..4f56f94bd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -192,7 +192,8 @@ final class MutableAny extends MutableValue {
* based on the dataTypes of each column. The intent is to decrease garbage when modifying the
* values of primitive columns.
*/
-final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
+final class SpecificMutableRow(val values: Array[MutableValue])
+ extends MutableRow with BaseGenericInternalRow {
def this(dataTypes: Seq[DataType]) =
this(
@@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def numFields: Int = values.length
- override def toSeq: Seq[Any] = values.map(_.boxed)
-
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c04fe734d5..c744e84d82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -25,6 +26,8 @@ import org.apache.spark.sql.types._
*/
abstract class BaseProjection extends Projection {}
+abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow
+
/**
* Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input
* [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]]
@@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
return new SpecificRow((InternalRow) r);
}
- final class SpecificRow extends ${classOf[MutableRow].getName} {
+ final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} {
$columns
@@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
- protected Object genericGet(int i) {
+ @Override
+ public Object genericGet(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index fd42fac3d2..11d10b2d8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -23,6 +23,130 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
+ * An extended version of [[InternalRow]] that implements all special getters, toString
+ * and equals/hashCode by `genericGet`.
+ */
+trait BaseGenericInternalRow extends InternalRow {
+
+ protected def genericGet(ordinal: Int): Any
+
+ // default implementation (slow)
+ private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+ override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def toString(): String = {
+ if (numFields == 0) {
+ "[empty row]"
+ } else {
+ val sb = new StringBuilder
+ sb.append("[")
+ sb.append(genericGet(0))
+ val len = numFields
+ var i = 1
+ while (i < len) {
+ sb.append(",")
+ sb.append(genericGet(i))
+ i += 1
+ }
+ sb.append("]")
+ sb.toString()
+ }
+ }
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[BaseGenericInternalRow]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[BaseGenericInternalRow]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numFields
+ if (len != other.numFields) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = genericGet(i)
+ val o2 = other.genericGet(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ // Custom hashCode function that matches the efficient code generated version.
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numFields
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ genericGet(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+}
+
+/**
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
* Setting a value through a primitive function implicitly marks that column as not null.
*/
@@ -83,7 +207,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
* Note that, while the array is not copied, and thus could technically be mutated after creation,
* this is not allowed.
*/
-class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow {
+class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
@@ -91,7 +215,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo
override protected def genericGet(ordinal: Int) = values(ordinal)
- override def toSeq: Seq[Any] = values
+ override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values
override def numFields: Int = values.length
@@ -110,7 +234,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
-class GenericMutableRow(values: Array[Any]) extends MutableRow {
+class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
@@ -118,7 +242,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow {
override protected def genericGet(ordinal: Int) = values(ordinal)
- override def toSeq: Seq[Any] = values
+ override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values
override def numFields: Int = values.length
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e310aee221..e323467af5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
val plan = GenerateMutableProjection.generate(expressions)()
- val actual = plan(new GenericMutableRow(length)).toSeq
+ val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq.fill(length)(true)
if (!checkResult(actual, expected)) {