aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-01-20 15:16:14 -0800
committerMichael Armbrust <michael@databricks.com>2015-01-20 15:16:14 -0800
commitd181c2a1fc40746947b97799b12e7dd8c213fa9c (patch)
tree5905b15884832311d5402767efe1279759245b08 /sql/catalyst
parentbc20a52b34e826895d0dcc1d783c021ebd456ebd (diff)
downloadspark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.tar.gz
spark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.tar.bz2
spark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.zip
[SPARK-5323][SQL] Remove Row's Seq inheritance.
Author: Reynold Xin <rxin@databricks.com> Closes #4115 from rxin/row-seq and squashes the following commits: e33abd8 [Reynold Xin] Fixed compilation error. cceb650 [Reynold Xin] Python test fixes, and removal of WrapDynamic. 0334a52 [Reynold Xin] mkString. 9cdeb7d [Reynold Xin] Hive tests. 15681c2 [Reynold Xin] Fix more test cases. ea9023a [Reynold Xin] Fixed a catalyst test. c5e2cb5 [Reynold Xin] Minor patch up. b9cab7c [Reynold Xin] [SPARK-5323][SQL] Remove Row's Seq inheritance.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala75
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala3
-rwxr-xr-xsql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala310
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala64
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala22
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala2
10 files changed, 266 insertions, 226 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 208ec92987..41bb4f012f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import scala.util.hashing.MurmurHash3
+
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -32,7 +34,7 @@ object Row {
* }
* }}}
*/
- def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq)
/**
* This method can be used to construct a [[Row]] with the given values.
@@ -43,6 +45,16 @@ object Row {
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray)
+
+ def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq)
+
+ /**
+ * Merge multiple rows into a single row, one after another.
+ */
+ def merge(rows: Row*): Row = {
+ // TODO: Improve the performance of this if used in performance critical part.
+ new GenericRow(rows.flatMap(_.toSeq).toArray)
+ }
}
@@ -103,7 +115,13 @@ object Row {
*
* @group row
*/
-trait Row extends Seq[Any] with Serializable {
+trait Row extends Serializable {
+ /** Number of elements in the Row. */
+ def size: Int = length
+
+ /** Number of elements in the Row. */
+ def length: Int
+
/**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
@@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable {
/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
- val l = length
+ val len = length
var i = 0
- while (i < l) {
+ while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
+
+ override def equals(that: Any): Boolean = that match {
+ case null => false
+ case that: Row =>
+ if (this.length != that.length) {
+ return false
+ }
+ var i = 0
+ val len = this.length
+ while (i < len) {
+ if (apply(i) != that.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // Using Scala's Seq hash code implementation.
+ var n = 0
+ var h = MurmurHash3.seqSeed
+ val len = length
+ while (n < len) {
+ h = MurmurHash3.mix(h, apply(n).##)
+ n += 1
+ }
+ MurmurHash3.finalizeHash(h, n)
+ }
+
+ /* ---------------------- utility methods for Scala ---------------------- */
+
+ /**
+ * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
+ */
+ def toSeq: Seq[Any]
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d280db83b2..191d16fb10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -84,8 +84,9 @@ trait ScalaReflection {
}
def convertRowToScala(r: Row, schema: StructType): Row = {
+ // TODO: This is very slow!!!
new GenericRow(
- r.zip(schema.fields.map(_.dataType))
+ r.toSeq.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26c855878d..417659eed5 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -272,9 +272,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
- def sfilter(dynamicUdf: (DynamicRow) => Boolean) =
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)
-
def sample(
fraction: Double,
withReplacement: Boolean = true,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1a2133bbbc..ece5ee7361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val casts = from.fields.zip(to.fields).map {
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
- buildCast[Row](_, row => Row(row.zip(casts).map {
+ // TODO: This is very slow!
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map {
case (v, cast) => if (v == null) null else cast(v)
}: _*))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index e7e81a21fd..db5d897ee5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -105,45 +105,45 @@ class JoinedRow extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -154,8 +154,16 @@ class JoinedRow extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -197,45 +205,45 @@ class JoinedRow2 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -246,8 +254,16 @@ class JoinedRow2 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -283,45 +299,45 @@ class JoinedRow3 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -332,8 +348,16 @@ class JoinedRow3 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -369,45 +393,45 @@ class JoinedRow4 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -418,8 +442,16 @@ class JoinedRow4 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -455,45 +487,45 @@ class JoinedRow5 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -504,7 +536,15 @@ class JoinedRow5 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 37d9f0ed5c..7434165f65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def length: Int = values.length
+ override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
+
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
@@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def iterator: Iterator[Any] = values.map(_.boxed).iterator
-
override def setString(ordinal: Int, value: String) = update(ordinal, value)
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
deleted file mode 100644
index e2f5c7332d..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.language.dynamics
-
-import org.apache.spark.sql.types.DataType
-
-/**
- * The data type representing [[DynamicRow]] values.
- */
-case object DynamicType extends DataType {
-
- /**
- * The default size of a value of the DynamicType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-}
-
-/**
- * Wrap a [[Row]] as a [[DynamicRow]].
- */
-case class WrapDynamic(children: Seq[Attribute]) extends Expression {
- type EvaluatedType = DynamicRow
-
- def nullable = false
-
- def dataType = DynamicType
-
- override def eval(input: Row): DynamicRow = input match {
- // Avoid copy for generic rows.
- case g: GenericRow => new DynamicRow(children, g.values)
- case otherRowType => new DynamicRow(children, otherRowType.toArray)
- }
-}
-
-/**
- * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language.
- * Since the type of the column is not known at compile time, all attributes are converted to
- * strings before being passed to the function.
- */
-class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
- extends GenericRow(values) with Dynamic {
-
- def selectDynamic(attributeName: String): String = {
- val ordinal = schema.indexWhere(_.name == attributeName)
- values(ordinal).toString
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index cc97cb4f50..69397a73a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
""".children : Seq[Tree]
}
- val iteratorFunction = {
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
- }
- q"override def iterator = Iterator[Any](..$allColumns)"
- }
-
val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
val applyFunction = {
val cases = (0 until expressions.size).map { i =>
@@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
"""
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+
val copyFunction =
- q"""
- override def copy() = new $genericRowType(this.toArray)
- """
+ q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
+
+ val toSeqFunction =
+ q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
val classBody =
nullFunctions ++ (
lengthDef +:
- iteratorFunction +:
applyFunction +:
updateFunction +:
equalsFunction +:
hashCodeFunction +:
copyFunction +:
+ toSeqFunction +:
(tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
val code = q"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index c22b842684..8df150e2f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,7 +44,7 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def iterator = Iterator.empty
+ override def toSeq = Seq.empty
override def length = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
@@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def iterator = values.iterator
+ override def toSeq = values.toSeq
override def length = values.length
@@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
// Custom hashCode function that matches the efficient code generated version.
- override def hashCode(): Int = {
+ override def hashCode: Int = {
var result: Int = 37
var i = 0
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 6df5db4c80..5138942a55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+ val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}