author     Reynold Xin <rxin@databricks.com>  2015-01-20 15:16:14 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-01-20 15:16:14 -0800
commit     d181c2a1fc40746947b97799b12e7dd8c213fa9c (patch)
tree       5905b15884832311d5402767efe1279759245b08 /sql
parent     bc20a52b34e826895d0dcc1d783c021ebd456ebd (diff)
download   spark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.tar.gz
           spark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.tar.bz2
           spark-d181c2a1fc40746947b97799b12e7dd8c213fa9c.zip
[SPARK-5323][SQL] Remove Row's Seq inheritance.
Author: Reynold Xin <rxin@databricks.com>

Closes #4115 from rxin/row-seq and squashes the following commits:

e33abd8 [Reynold Xin] Fixed compilation error.
cceb650 [Reynold Xin] Python test fixes, and removal of WrapDynamic.
0334a52 [Reynold Xin] mkString.
9cdeb7d [Reynold Xin] Hive tests.
15681c2 [Reynold Xin] Fix more test cases.
ea9023a [Reynold Xin] Fixed a catalyst test.
c5e2cb5 [Reynold Xin] Minor patch up.
b9cab7c [Reynold Xin] [SPARK-5323][SQL] Remove Row's Seq inheritance.
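The practical upshot of the change: Row no longer inherits the Seq[Any] API, so Seq-style operations must go through the new toSeq, while length, mkString, equals, and hashCode are reimplemented directly on Row. A minimal before/after sketch (variable names are illustrative, not from the patch):

    import org.apache.spark.sql.Row

    val row = Row(1, "a", null)

    // Before this patch, Row extended Seq[Any], so these worked implicitly:
    //   row.iterator; row.zip(types); row1 ++ row2
    // After it, convert explicitly:
    val fields = row.toSeq                    // Seq(1, "a", null)
    val merged = Row.merge(row, Row(2.0))     // replaces row1 ++ row2
    val pretty = row.mkString("[", ",", "]")  // mkString stays on Row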
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 3
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 310
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala | 64
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 146
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 242
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 416
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 185
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala | 91
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 20
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala | 48
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 34
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 18
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala | 58
47 files changed, 1018 insertions, 956 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 208ec92987..41bb4f012f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import scala.util.hashing.MurmurHash3
+
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -32,7 +34,7 @@ object Row {
* }
* }}}
*/
- def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq)
/**
* This method can be used to construct a [[Row]] with the given values.
@@ -43,6 +45,16 @@ object Row {
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray)
+
+ def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq)
+
+ /**
+ * Merge multiple rows into a single row, one after another.
+ */
+ def merge(rows: Row*): Row = {
+ // TODO: Improve the performance of this if used in a performance-critical path.
+ new GenericRow(rows.flatMap(_.toSeq).toArray)
+ }
}
@@ -103,7 +115,13 @@ object Row {
*
* @group row
*/
-trait Row extends Seq[Any] with Serializable {
+trait Row extends Serializable {
+ /** Number of elements in the Row. */
+ def size: Int = length
+
+ /** Number of elements in the Row. */
+ def length: Int
+
/**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
@@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable {
/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
- val l = length
+ val len = length
var i = 0
- while (i < l) {
+ while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
+
+ override def equals(that: Any): Boolean = that match {
+ case null => false
+ case that: Row =>
+ if (this.length != that.length) {
+ return false
+ }
+ var i = 0
+ val len = this.length
+ while (i < len) {
+ if (apply(i) != that.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // Using Scala's Seq hash code implementation.
+ var n = 0
+ var h = MurmurHash3.seqSeed
+ val len = length
+ while (n < len) {
+ h = MurmurHash3.mix(h, apply(n).##)
+ n += 1
+ }
+ MurmurHash3.finalizeHash(h, n)
+ }
+
+ /* ---------------------- utility methods for Scala ---------------------- */
+
+ /**
+ * Returns a Scala Seq representing the row. Elements are placed in the same order in the Seq.
+ */
+ def toSeq: Seq[Any]
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
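A quick sanity-check sketch of the new companion-object factories and the value-based equals/hashCode defined above (the assertions follow directly from the implementations in this hunk):

    import org.apache.spark.sql.Row

    val a = Row(1, "x")
    val b = Row.fromSeq(Seq(1, "x"))
    val c = Row.fromTuple((1, "x"))

    // equals compares element by element, so the construction route
    // is irrelevant, and equal rows hash equally.
    assert(a == b && b == c)
    assert(a.hashCode == b.hashCode)

    // merge concatenates fields left to right.
    assert(Row.merge(Row(1), Row("x")) == Row(1, "x"))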
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d280db83b2..191d16fb10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -84,8 +84,9 @@ trait ScalaReflection {
}
def convertRowToScala(r: Row, schema: StructType): Row = {
+ // TODO: This is very slow!!!
new GenericRow(
- r.zip(schema.fields.map(_.dataType))
+ r.toSeq.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26c855878d..417659eed5 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -272,9 +272,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
- def sfilter(dynamicUdf: (DynamicRow) => Boolean) =
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)
-
def sample(
fraction: Double,
withReplacement: Boolean = true,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1a2133bbbc..ece5ee7361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val casts = from.fields.zip(to.fields).map {
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
- buildCast[Row](_, row => Row(row.zip(casts).map {
+ // TODO: This is very slow!
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map {
case (v, cast) => if (v == null) null else cast(v)
}: _*))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index e7e81a21fd..db5d897ee5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -105,45 +105,45 @@ class JoinedRow extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -154,8 +154,16 @@ class JoinedRow extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -197,45 +205,45 @@ class JoinedRow2 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -246,8 +254,16 @@ class JoinedRow2 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -283,45 +299,45 @@ class JoinedRow3 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -332,8 +348,16 @@ class JoinedRow3 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -369,45 +393,45 @@ class JoinedRow4 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -418,8 +442,16 @@ class JoinedRow4 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -455,45 +487,45 @@ class JoinedRow5 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -504,7 +536,15 @@ class JoinedRow5 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
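All five JoinedRow variants repeat the same index-delegation pattern, with size uniformly replaced by length now that Row defines size in terms of length. A standalone toy version of that pattern (not the Spark class):

    // Indices below row1.length resolve in the left row; the rest
    // shift down by row1.length and resolve in the right row.
    trait MiniRow {
      def length: Int
      def apply(i: Int): Any
      def toSeq: Seq[Any] = (0 until length).map(apply)
    }

    class JoinedMiniRow(row1: MiniRow, row2: MiniRow) extends MiniRow {
      override def length: Int = row1.length + row2.length
      override def apply(i: Int): Any =
        if (i < row1.length) row1(i) else row2(i - row1.length)
    }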
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 37d9f0ed5c..7434165f65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def length: Int = values.length
+ override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
+
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
@@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def iterator: Iterator[Any] = values.map(_.boxed).iterator
-
override def setString(ordinal: Int, value: String) = update(ordinal, value)
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
deleted file mode 100644
index e2f5c7332d..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.language.dynamics
-
-import org.apache.spark.sql.types.DataType
-
-/**
- * The data type representing [[DynamicRow]] values.
- */
-case object DynamicType extends DataType {
-
- /**
- * The default size of a value of the DynamicType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-}
-
-/**
- * Wrap a [[Row]] as a [[DynamicRow]].
- */
-case class WrapDynamic(children: Seq[Attribute]) extends Expression {
- type EvaluatedType = DynamicRow
-
- def nullable = false
-
- def dataType = DynamicType
-
- override def eval(input: Row): DynamicRow = input match {
- // Avoid copy for generic rows.
- case g: GenericRow => new DynamicRow(children, g.values)
- case otherRowType => new DynamicRow(children, otherRowType.toArray)
- }
-}
-
-/**
- * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language.
- * Since the type of the column is not known at compile time, all attributes are converted to
- * strings before being passed to the function.
- */
-class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
- extends GenericRow(values) with Dynamic {
-
- def selectDynamic(attributeName: String): String = {
- val ordinal = schema.indexWhere(_.name == attributeName)
- values(ordinal).toString
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index cc97cb4f50..69397a73a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
""".children : Seq[Tree]
}
- val iteratorFunction = {
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
- }
- q"override def iterator = Iterator[Any](..$allColumns)"
- }
-
val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
val applyFunction = {
val cases = (0 until expressions.size).map { i =>
@@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
"""
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+
val copyFunction =
- q"""
- override def copy() = new $genericRowType(this.toArray)
- """
+ q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
+
+ val toSeqFunction =
+ q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
val classBody =
nullFunctions ++ (
lengthDef +:
- iteratorFunction +:
applyFunction +:
updateFunction +:
equalsFunction +:
hashCodeFunction +:
copyFunction +:
+ toSeqFunction +:
(tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
val code = q"""
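For concreteness, here is roughly what those quasiquotes emit for a two-column projection; copy() and toSeq now materialize the columns directly instead of relying on an inherited Seq iterator (illustrative expansion only, using the generator's c$i naming convention):

    override def copy() = new GenericRow(Array[Any](
      if (isNullAt(0)) { null } else { c0 },
      if (isNullAt(1)) { null } else { c1 }))

    override def toSeq: Seq[Any] = Seq(
      if (isNullAt(0)) { null } else { c0 },
      if (isNullAt(1)) { null } else { c1 })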
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index c22b842684..8df150e2f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,7 +44,7 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def iterator = Iterator.empty
+ override def toSeq = Seq.empty
override def length = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
@@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def iterator = values.iterator
+ override def toSeq = values.toSeq
override def length = values.length
@@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
// Custom hashCode function that matches the efficient code generated version.
- override def hashCode(): Int = {
+ override def hashCode: Int = {
var result: Int = 37
var i = 0
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 6df5db4c80..5138942a55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+ val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index ae4d8ba90c..d1e21dffeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -332,25 +332,6 @@ class SchemaRDD(
/**
* :: Experimental ::
- * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use
- * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of
- * the column is not known at compile time, all attributes are converted to strings before
- * being passed to the function.
- *
- * {{{
- * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
- * }}}
- *
- * @group Query
- */
- @Experimental
- def where(dynamicUdf: (DynamicRow) => Boolean) =
- new SchemaRDD(
- sqlContext,
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan))
-
- /**
- * :: Experimental ::
* Returns a sampled version of the underlying dataset.
*
* @group Query
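With the Dynamic-based overload gone, the documented use case migrates to the typed DSL (or plain SQL). A hypothetical equivalent of the removed example, assuming the catalyst symbol DSL is in scope and schemaRDD is an existing SchemaRDD:

    import org.apache.spark.sql.catalyst.dsl.expressions._

    // was: schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
    schemaRDD.where('firstName === "Bob" && 'lastName === "Smith")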
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 065fae3c83..11d5943fb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,7 +21,6 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- val stats = Row.fromSeq(
- columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+ val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan(
// Extract rows via column accessors
new Iterator[Row] {
+ private[this] val rowLen = nextRow.length
override def next() = {
var i = 0
- while (i < nextRow.length) {
+ while (i < rowLen) {
columnAccessors(i).extractTo(nextRow, i)
i += 1
}
@@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
def statsString = relation.partitionStatistics.schema
- .zip(cachedBatch.stats)
+ .zip(cachedBatch.stats.toSeq)
.map { case (a, s) => s"${a.name}: $s" }
.mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 6467324839..68a5b1de76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
while (from.hasRemaining) {
columnType.extract(from, value, 0)
- if (value.head == currentValue.head) {
+ if (value(0) == currentValue(0)) {
currentRun += 1
} else {
// Writes current run
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 46245cd5a1..4d7e338e8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -144,7 +144,7 @@ package object debug {
case (null, _) =>
case (row: Row, StructType(fields)) =>
- row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) }
+ row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (s: Seq[_], ArrayType(elemType, _)) =>
s.foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 7ed64aad10..b85021acc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -116,9 +116,9 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Seq[Any], struct: StructType) =>
+ case (row: Row, struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
- row.zip(fields).map {
+ row.toSeq.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
}.toArray
@@ -143,7 +143,8 @@ object EvaluatePython {
* Convert Row into Java Array (for pickled into Python)
*/
def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
- row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+ // TODO: this is slow!
+ row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
}
// Converts value to the type specified by the data type.
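The match-arm change here is forced rather than cosmetic: once Row stops extending Seq[Any], a pattern of the form case (row: Seq[Any], ...) silently stops matching struct values. A toy illustration of the hazard (sketch, not Spark code):

    import org.apache.spark.sql.Row

    // After this patch a Row is not a Seq, so rows and sequences flowing
    // through the same match need distinct arms.
    def describe(v: Any): String = v match {
      case r: Row    => "row"    // explicit Row case is now required
      case s: Seq[_] => "seq"    // no longer catches Row values
      case _         => "other"
    }

    assert(describe(Row(1)) == "row")
    assert(describe(Seq(1)) == "seq")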
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index db70a7eac7..9171939f7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging {
gen.writeEndArray()
case (MapType(kv,vv, _), v: Map[_,_]) =>
- gen.writeStartObject
+ gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
valWriter(vv,p._2)
}
- gen.writeEndObject
+ gen.writeEndObject()
- case (StructType(ty), v: Seq[_]) =>
+ case (StructType(ty), v: Row) =>
gen.writeStartObject()
- ty.zip(v).foreach {
+ ty.zip(v.toSeq).foreach {
case (_, null) =>
case (field, v) =>
gen.writeFieldName(field.name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index b4aed04199..9d9150246c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -66,7 +66,7 @@ private[sql] object CatalystConverter {
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
type ArrayScalaType[T] = Seq[T]
- type StructScalaType[T] = Seq[T]
+ type StructScalaType[T] = Row
type MapScalaType[K, V] = Map[K, V]
protected[parquet] def createConverter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 2bcfe28456..afbfe214f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
testData2.groupBy('a)('a, sum('b)),
- Seq((1,3),(2,3),(3,3))
+ Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
- 9
+ Row(9)
)
checkAnswer(
testData2.aggregate(sum('b)),
- 9
+ Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
testData.where($"key" === 1).select($"value"),
- Seq(Seq("1")))
+ Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select *") {
@@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest {
test("simple select") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select with functions") {
checkAnswer(
testData.select(sum('value), avg('value), count(1)),
- Seq(Seq(5050.0, 50.5, 100)))
+ Row(5050.0, 50.5, 100))
checkAnswer(
testData2.select('a + 'b, 'a < 'b),
Seq(
- Seq(2, false),
- Seq(3, true),
- Seq(3, false),
- Seq(4, false),
- Seq(4, false),
- Seq(5, false)))
+ Row(2, false),
+ Row(3, true),
+ Row(3, false),
+ Row(4, false),
+ Row(4, false),
+ Row(5, false)))
checkAnswer(
testData2.select(sumDistinct('a)),
- Seq(Seq(6)))
+ Row(6))
}
test("global sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.orderBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.desc),
- Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+ Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
- Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.collect().sortBy(_.data(0)).toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).asc),
- mapData.collect().sortBy(_.data(1)).toSeq)
+ arrayData.orderBy('data.getItem(1).asc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).desc),
- mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ arrayData.orderBy('data.getItem(1).desc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("partition wide sorting") {
@@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest {
// (3, 2)
checkAnswer(
testData2.sortBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.sortBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
checkAnswer(
testData2.sortBy('a.desc, 'b.desc),
- Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2)))
+ Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
checkAnswer(
testData2.sortBy('a.desc, 'b.asc),
- Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2)))
+ Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
}
test("limit") {
@@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.limit(1),
- arrayData.take(1).toSeq)
+ arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
mapData.limit(1),
- mapData.take(1).toSeq)
+ mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
test("SPARK-3395 limit distinct") {
@@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest {
.registerTempTable("onerow")
checkAnswer(
sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- (1, 1, 1, 1) ::
- (1, 1, 1, 2) :: Nil)
+ Row(1, 1, 1, 1) ::
+ Row(1, 1, 1, 2) :: Nil)
}
test("SPARK-3858 generator qualifiers are discarded") {
@@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest {
arrayData.as('ad)
.generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
.select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Seq(_)))
+ Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
}
test("average") {
checkAnswer(
testData2.aggregate(avg('a)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
- (2.0, 6.0) :: Nil)
+ Row(2.0, 6.0) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a)),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
testData3.aggregate(avg('b)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData3.aggregate(avg('b), countDistinct('b)),
- (2.0, 1) :: Nil)
+ Row(2.0, 1))
checkAnswer(
testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
- (2.0, 2.0) :: Nil)
+ Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
emptyTableData.aggregate(avg('a)),
- null)
+ Row(null))
checkAnswer(
emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
- (null, null) :: Nil)
+ Row(null, null))
}
test("count") {
@@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
testData2.aggregate(count('a), sumDistinct('a)), // non-partial
- (6, 6.0) :: Nil)
+ Row(6, 6.0))
}
test("null count") {
checkAnswer(
testData3.groupBy('a)('a, count('b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.groupBy('a)('a, count('a + 'b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
- (2, 1, 2, 2, 1) :: Nil
+ Row(2, 1, 2, 2, 1)
)
checkAnswer(
testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
- (1, 1, 2) :: Nil
+ Row(1, 1, 2)
)
}
@@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
- (0, null) :: Nil)
+ Row(0, null))
}
test("zero sum") {
checkAnswer(
emptyTableData.aggregate(sum('a)),
- null)
+ Row(null))
}
test("zero sum distinct") {
checkAnswer(
emptyTableData.aggregate(sumDistinct('a)),
- null)
+ Row(null))
}
test("except") {
checkAnswer(
lowerCaseData.except(upperCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
}
@@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest {
test("intersect") {
checkAnswer(
lowerCaseData.intersect(lowerCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
@@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select(Star(None), foo.call('key, 'value)).limit(3),
- (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+ Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(math.sqrt(n)))
+ (1 to 100).map(n => Row(math.sqrt(n)))
)
checkAnswer(
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
- (1 to 100).map(n => Seq(math.sqrt(n), n))
+ (1 to 100).map(n => Row(math.sqrt(n), n))
)
checkAnswer(
testData.select(sqrt(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("abs") {
checkAnswer(
testData.select(abs('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
negativeData.select(abs('key)).orderBy('key desc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
testData.select(abs(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
)
checkAnswer(
testData.select(upper('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(upper(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
)
checkAnswer(
testData.select(lower('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(lower(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
}
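The large test churn above, and in the suites that follow, applies one mechanical rule: expected answers passed to checkAnswer become Row objects (or Seqs of Rows) instead of bare values, tuples, or Seqs, since checkAnswer can no longer treat a Row as a Seq. A sketch of the pattern, assuming it runs inside a QueryTest suite with some query in scope:

    // Before: checkAnswer(query, 9)
    checkAnswer(query, Row(9))

    // Before: checkAnswer(query, (1, "a") :: (2, "b") :: Nil)
    checkAnswer(query, Row(1, "a") :: Row(2, "b") :: Nil)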
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e5ab16f9dd..cd36da7751 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val y = testData2.where('a === 1).as('y)
checkAnswer(
x.join(y).where("x.a".attr === "y.a".attr),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil
+ Row(1,1,1,1) ::
+ Row(1,1,1,2) ::
+ Row(1,2,1,1) ::
+ Row(1,2,1,2) :: Nil
)
}
@@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
checkAnswer(
testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
}
test("left outer join") {
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
checkAnswer(
sql(
@@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6) :: Nil)
}
test("right outer join") {
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6))
checkAnswer(
sql(
@@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
}
test("full outer join") {
@@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
    // Make sure we are using UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
@@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
checkAnswer(
sql(
@@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
}
test("broadcasted left semi join operator selection") {
@@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left semi join") {
val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(rdd,
- (1, 1) ::
- (1, 2) ::
- (2, 1) ::
- (2, 2) ::
- (3, 1) ::
- (3, 2) :: Nil)
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(2, 1) ::
+ Row(2, 2) ::
+ Row(3, 1) ::
+ Row(3, 2) :: Nil)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 68ddecc7f6..42a21c148d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -47,26 +47,17 @@ class QueryTest extends PlanTest {
* @param rdd the [[SchemaRDD]] to be executed
   * @param expectedAnswer the expected result, given as a Seq[Row] (or a single Row via the overload below).
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
- val convertedAnswer = expectedAnswer match {
- case s: Seq[_] if s.isEmpty => s
- case s: Seq[_] if s.head.isInstanceOf[Product] &&
- !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
- case s: Seq[_] => s
- case singleItem => Seq(Seq(singleItem))
- }
-
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- def prepareAnswer(answer: Seq[Any]): Seq[Any] = {
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
- val converted = answer.map {
- case s: Seq[_] => s.map {
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case o => o
- }
- case o => o
+ })
}
if (!isSorted) converted.sortBy(_.toString) else converted
}
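
The BigDecimal conversion above exists because java.math.BigDecimal.equals is scale-sensitive while scala.BigDecimal compares numerically, matching compareTo semantics. Concretely:

    // 1.0 and 1.00 are unequal as java.math.BigDecimal but equal as scala.BigDecimal.
    assert(!new java.math.BigDecimal("1.0").equals(new java.math.BigDecimal("1.00")))
    assert(BigDecimal("1.0") == BigDecimal("1.00"))
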
@@ -82,7 +73,7 @@ class QueryTest extends PlanTest {
""".stripMargin)
}
- if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -92,15 +83,19 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
s"== Spark Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
- def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = {
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
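
With the Seq[Row] signature above, callers spell expected answers out as Rows rather than tuples. Illustrative usage, mirroring tests elsewhere in this patch:

    import org.apache.spark.sql.Row

    // A single expected row hits the Row overload; multiple rows use Seq[Row].
    checkAnswer(sql("SELECT value FROM testData WHERE key = 1"), Row("1"))
    checkAnswer(sql("SELECT a FROM testData2 SORT BY a"), Seq(1, 1, 2, 2, 3, 3).map(Row(_)))
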
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 54fabc5c91..03b44ca1d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -46,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
sql("SELECT a FROM testData2 SORT BY a"),
- Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_))
+      Seq(1, 1, 2, 2, 3, 3).map(Row(_))
)
}
@@ -70,13 +70,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3176 Added Parser of SQL ABS()") {
checkAnswer(
sql("SELECT ABS(-1.3)"),
- 1.3)
+ Row(1.3))
checkAnswer(
sql("SELECT ABS(0.0)"),
- 0.0)
+ Row(0.0))
checkAnswer(
sql("SELECT ABS(2.5)"),
- 2.5)
+ Row(2.5))
}
test("aggregation with codegen") {
@@ -89,13 +89,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3176 Added Parser of SQL LAST()") {
checkAnswer(
sql("SELECT LAST(n) FROM lowerCaseData"),
- 4)
+ Row(4))
}
test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
- "test")
+ Row("test"))
}
test("SQRT") {
@@ -115,40 +115,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-2407 Added Parser of SQL SUBSTR()") {
checkAnswer(
sql("SELECT substr(tableName, 1, 2) FROM tableName"),
- "te")
+ Row("te"))
checkAnswer(
sql("SELECT substr(tableName, 3) FROM tableName"),
- "st")
+ Row("st"))
checkAnswer(
sql("SELECT substring(tableName, 1, 2) FROM tableName"),
- "te")
+ Row("te"))
checkAnswer(
sql("SELECT substring(tableName, 3) FROM tableName"),
- "st")
+ Row("st"))
}
test("SPARK-3173 Timestamp support in the parser") {
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"""SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003'
AND time>'1970-01-01 00:00:00.001'"""),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")),
- Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+ Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")),
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time='123'"),
@@ -158,13 +158,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("index into array") {
checkAnswer(
sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
- arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
+ arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect())
}
test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
- Seq((3,1), (3,2))
+ Seq(Row(3,1), Row(3,2))
)
}
@@ -173,7 +173,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql(
"SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"),
arrayData.map(d =>
- (d.nestedData,
+ Row(d.nestedData,
d.nestedData(0)(0),
d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq)
}
@@ -181,13 +181,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("agg") {
checkAnswer(
sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
- Seq((1,3),(2,3),(3,3)))
+ Seq(Row(1,3), Row(2,3), Row(3,3)))
}
test("aggregates with nulls") {
checkAnswer(
sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
- (1, 3, 2, 6, 3) :: Nil
+ Row(1, 3, 2, 6, 3)
)
}
@@ -200,29 +200,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("simple select") {
checkAnswer(
sql("SELECT value FROM testData WHERE key = 1"),
- Seq(Seq("1")))
+ Row("1"))
}
def sortTest() = {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"),
- Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"),
- Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+ Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
- Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
sql("SELECT b FROM binaryData ORDER BY a ASC"),
- (1 to 5).map(Row(_)).toSeq)
+ (1 to 5).map(Row(_)))
checkAnswer(
sql("SELECT b FROM binaryData ORDER BY a DESC"),
@@ -230,19 +230,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
- arrayData.collect().sortBy(_.data(0)).toSeq)
+ arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
- arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+ arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
- mapData.collect().sortBy(_.data(1)).toSeq)
+ mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
- mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq)
}
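
Row.fromTuple, used above to adapt collected case-class results to Seq[Row], turns any Product into a Row with one field per constructor parameter. A sketch with a hypothetical case class standing in for the suite's test data:

    import org.apache.spark.sql.Row

    case class MapData(data: Map[Int, String])  // hypothetical stand-in
    val asRow = Row.fromTuple(MapData(Map(1 -> "a1")))
    assert(asRow.toSeq == Seq(Map(1 -> "a1")))
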
test("sorting") {
@@ -266,94 +266,94 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM arrayData LIMIT 1"),
- arrayData.collect().take(1).toSeq)
+ arrayData.collect().take(1).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData LIMIT 1"),
- mapData.collect().take(1).toSeq)
+ mapData.collect().take(1).map(Row.fromTuple).toSeq)
}
test("from follow multiple brackets") {
checkAnswer(sql(
"select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"),
- 1
+ Row(1)
)
checkAnswer(sql(
"select key from (select * from testData) x limit 1"),
- 1
+ Row(1)
)
checkAnswer(sql(
"select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"),
- 1
+ Row(1)
)
}
test("average") {
checkAnswer(
sql("SELECT AVG(a) FROM testData2"),
- 2.0)
+ Row(2.0))
}
test("average overflow") {
checkAnswer(
sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"),
- Seq((2147483645.0,1),(2.0,2)))
+ Seq(Row(2147483645.0,1), Row(2.0,2)))
}
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
- testData2.count())
+ Row(testData2.count()))
}
test("count distinct") {
checkAnswer(
sql("SELECT COUNT(DISTINCT b) FROM testData2"),
- 2)
+ Row(2))
}
test("approximate count distinct") {
checkAnswer(
sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
- 3)
+ Row(3))
}
test("approximate count distinct with user provided standard deviation") {
checkAnswer(
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
- 3)
+ Row(3))
}
test("null count") {
checkAnswer(
sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"),
- Seq((1, 0), (2, 1)))
+ Seq(Row(1, 0), Row(2, 1)))
checkAnswer(
sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
- (2, 1, 2, 2, 1) :: Nil)
+ Row(2, 1, 2, 2, 1))
}
test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("inner join ON, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("inner join, where, multiple matches") {
@@ -363,10 +363,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (SELECT * FROM testData2 WHERE a = 1) x JOIN
| (SELECT * FROM testData2 WHERE a = 1) y
|WHERE x.a = y.a""".stripMargin),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil)
+ Row(1,1,1,1) ::
+ Row(1,1,1,2) ::
+ Row(1,2,1,1) ::
+ Row(1,2,1,2) :: Nil)
}
test("inner join, no matches") {
@@ -397,38 +397,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| SELECT * FROM testData) y
|WHERE x.key = y.key""".stripMargin),
testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
ignore("cartesian product join") {
checkAnswer(
testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
}
test("left outer join") {
checkAnswer(
sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
}
test("right outer join") {
checkAnswer(
sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("full outer join") {
@@ -440,12 +440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (SELECT * FROM upperCaseData WHERE N >= 3) rightTable
| ON leftTable.N = rightTable.N
""".stripMargin),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+      Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("SPARK-3349 partitioning after limit") {
@@ -457,12 +457,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.registerTempTable("subset2")
checkAnswer(
sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"),
- (3, "c", 3) ::
- (4, "d", 4) :: Nil)
+ Row(3, "c", 3) ::
+ Row(4, "d", 4) :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
- (1, "a", 1) ::
- (2, "b", 2) :: Nil)
+ Row(1, "a", 1) ::
+ Row(2, "b", 2) :: Nil)
}
test("mixed-case keywords") {
@@ -474,28 +474,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (sElEcT * FROM upperCaseData whERe N >= 3) rightTable
| oN leftTable.N = rightTable.N
""".stripMargin),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("select with table name as qualifier") {
checkAnswer(
sql("SELECT testData.value FROM testData WHERE testData.key = 1"),
- Seq(Seq("1")))
+ Row("1"))
}
test("inner join ON with table name as qualifier") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("qualified select with inner join ON with table name as qualifier") {
@@ -503,72 +503,72 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " +
"ON lowerCaseData.n = upperCaseData.N"),
Seq(
- (1, "A"),
- (2, "B"),
- (3, "C"),
- (4, "D")))
+ Row(1, "A"),
+ Row(2, "B"),
+ Row(3, "C"),
+ Row(4, "D")))
}
test("system function upper()") {
checkAnswer(
sql("SELECT n,UPPER(l) FROM lowerCaseData"),
Seq(
- (1, "A"),
- (2, "B"),
- (3, "C"),
- (4, "D")))
+ Row(1, "A"),
+ Row(2, "B"),
+ Row(3, "C"),
+ Row(4, "D")))
checkAnswer(
sql("SELECT n, UPPER(s) FROM nullStrings"),
Seq(
- (1, "ABC"),
- (2, "ABC"),
- (3, null)))
+ Row(1, "ABC"),
+ Row(2, "ABC"),
+ Row(3, null)))
}
test("system function lower()") {
checkAnswer(
sql("SELECT N,LOWER(L) FROM upperCaseData"),
Seq(
- (1, "a"),
- (2, "b"),
- (3, "c"),
- (4, "d"),
- (5, "e"),
- (6, "f")))
+ Row(1, "a"),
+ Row(2, "b"),
+ Row(3, "c"),
+ Row(4, "d"),
+ Row(5, "e"),
+ Row(6, "f")))
checkAnswer(
sql("SELECT n, LOWER(s) FROM nullStrings"),
Seq(
- (1, "abc"),
- (2, "abc"),
- (3, null)))
+ Row(1, "abc"),
+ Row(2, "abc"),
+ Row(3, null)))
}
test("UNION") {
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
- (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
- (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") ::
+ Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
- (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
+ Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
- (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
- (4, "d") :: (4, "d") :: Nil)
+ Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") ::
+ Row(4, "d") :: Row(4, "d") :: Nil)
}
test("UNION with column mismatches") {
// Column name mismatches are allowed.
checkAnswer(
sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
- (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
- (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") ::
+ Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil)
    // Column type mismatches are resolved by forcing a type coercion.
checkAnswer(
sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
- ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
+ ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger a TreeNodeException.
intercept[TreeNodeException[_]] {
@@ -579,10 +579,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("EXCEPT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
checkAnswer(
@@ -592,10 +592,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("INTERSECT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil)
}
@@ -613,25 +613,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(s"$testKey=$testVal"))
+ Row(s"$testKey=$testVal")
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(s"$testKey=$testVal"),
- Seq(s"${testKey + testKey}=${testVal + testVal}"))
+ Row(s"$testKey=$testVal"),
+ Row(s"${testKey + testKey}=${testVal + testVal}"))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(s"$testKey=$testVal"))
+ Row(s"$testKey=$testVal")
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(s"$nonexistentKey=<undefined>"))
+ Row(s"$nonexistentKey=<undefined>")
)
conf.clear()
}
@@ -655,17 +655,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
schemaRDD1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
- (1, "A1", true, null) ::
- (2, "B2", false, null) ::
- (3, "C3", true, null) ::
- (4, "D4", true, 2147483644) :: Nil)
+ Row(1, "A1", true, null) ::
+ Row(2, "B2", false, null) ::
+ Row(3, "C3", true, null) ::
+ Row(4, "D4", true, 2147483644) :: Nil)
checkAnswer(
sql("SELECT f1, f4 FROM applySchema1"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
val schema2 = StructType(
StructField("f1", StructType(
@@ -685,17 +685,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
schemaRDD2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
- (Seq(1, true), Map("A1" -> null)) ::
- (Seq(2, false), Map("B2" -> null)) ::
- (Seq(3, true), Map("C3" -> null)) ::
- (Seq(4, true), Map("D4" -> 2147483644)) :: Nil)
+ Row(Row(1, true), Map("A1" -> null)) ::
+ Row(Row(2, false), Map("B2" -> null)) ::
+ Row(Row(3, true), Map("C3" -> null)) ::
+ Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil)
checkAnswer(
sql("SELECT f1.f11, f2['D4'] FROM applySchema2"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
// The value of a MapType column can be a mutable map.
val rowRDD3 = unparsedStrings.map { r =>
@@ -711,26 +711,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT f1.f11, f2['D4'] FROM applySchema3"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
}
test("SPARK-3423 BETWEEN") {
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"),
- Seq((5, "5"), (6, "6"), (7, "7"))
+ Seq(Row(5, "5"), Row(6, "6"), Row(7, "7"))
)
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"),
- Seq((7, "7"))
+ Row(7, "7")
)
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"),
- Seq()
+ Nil
)
}
@@ -738,7 +738,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// TODO Ensure true/false string letter casing is consistent with Hive in all cases.
checkAnswer(
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
- ("true", "false") :: Nil)
+ Row("true", "false"))
}
test("metadata is propagated correctly") {
@@ -768,17 +768,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3371 Renaming a function expression with group by gives error") {
udf.register("len", (s: String) => s.length)
checkAnswer(
- sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
+ Row(1))
}
test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
checkAnswer(
- sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"),
+ Row(1))
}
test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
checkAnswer(
- sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"),
+ Row(1))
}
test("throw errors for non-aggregate attributes with aggregation") {
@@ -808,130 +811,131 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Test to check we can use Long.MinValue") {
checkAnswer(
- sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue
+ sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue)
)
checkAnswer(
- sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq
+ sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"),
+ (1 to 100).map(Row(_)).toSeq
)
}
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), 0.3
+ sql("SELECT 0.3"), Row(0.3)
)
checkAnswer(
- sql("SELECT -0.8"), -0.8
+ sql("SELECT -0.8"), Row(-0.8)
)
checkAnswer(
- sql("SELECT .5"), 0.5
+ sql("SELECT .5"), Row(0.5)
)
checkAnswer(
- sql("SELECT -.18"), -0.18
+ sql("SELECT -.18"), Row(-0.18)
)
}
test("Auto cast integer type") {
checkAnswer(
- sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L
+ sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L)
)
checkAnswer(
- sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L
+ sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L)
)
checkAnswer(
- sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808")
+ sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
)
checkAnswer(
- sql("SELECT -9223372036854775809"), new java.math.BigDecimal("-9223372036854775809")
+ sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
)
}
test("Test to check we can apply sign to expression") {
checkAnswer(
- sql("SELECT -100"), -100
+ sql("SELECT -100"), Row(-100)
)
checkAnswer(
- sql("SELECT +230"), 230
+ sql("SELECT +230"), Row(230)
)
checkAnswer(
- sql("SELECT -5.2"), -5.2
+ sql("SELECT -5.2"), Row(-5.2)
)
checkAnswer(
- sql("SELECT +6.8"), 6.8
+ sql("SELECT +6.8"), Row(6.8)
)
checkAnswer(
- sql("SELECT -key FROM testData WHERE key = 2"), -2
+ sql("SELECT -key FROM testData WHERE key = 2"), Row(-2)
)
checkAnswer(
- sql("SELECT +key FROM testData WHERE key = 3"), 3
+ sql("SELECT +key FROM testData WHERE key = 3"), Row(3)
)
checkAnswer(
- sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2
+ sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2)
)
checkAnswer(
- sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9
+ sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9)
)
checkAnswer(
- sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10
+ sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10)
)
checkAnswer(
- sql("SELECT -MAX(key) FROM testData"), -100
+ sql("SELECT -MAX(key) FROM testData"), Row(-100)
)
checkAnswer(
- sql("SELECT +MAX(key) FROM testData"), 100
+ sql("SELECT +MAX(key) FROM testData"), Row(100)
)
checkAnswer(
- sql("SELECT - (-10)"), 10
+ sql("SELECT - (-10)"), Row(10)
)
checkAnswer(
- sql("SELECT + (-key) FROM testData WHERE key = 32"), -32
+ sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32)
)
checkAnswer(
- sql("SELECT - (+Max(key)) FROM testData"), -100
+ sql("SELECT - (+Max(key)) FROM testData"), Row(-100)
)
checkAnswer(
- sql("SELECT - - 3"), 3
+ sql("SELECT - - 3"), Row(3)
)
checkAnswer(
- sql("SELECT - + 20"), -20
+ sql("SELECT - + 20"), Row(-20)
)
checkAnswer(
- sql("SELEcT - + 45"), -45
+ sql("SELEcT - + 45"), Row(-45)
)
checkAnswer(
- sql("SELECT + + 100"), 100
+ sql("SELECT + + 100"), Row(100)
)
checkAnswer(
- sql("SELECT - - Max(key) FROM testData"), 100
+ sql("SELECT - - Max(key) FROM testData"), Row(100)
)
checkAnswer(
- sql("SELECT + - key FROM testData WHERE key = 33"), -33
+ sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33)
)
}
@@ -943,7 +947,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
|JOIN testData b ON a.key = b.key
|JOIN testData c ON a.key = c.key
""".stripMargin),
- (1 to 100).map(i => Seq(i, i, i)))
+ (1 to 100).map(i => Row(i, i, i)))
}
test("SPARK-3483 Special chars in column names") {
@@ -953,19 +957,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-3814 Support Bitwise & operator") {
- checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise | operator") {
- checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise ^ operator") {
- checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise ~ operator") {
- checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2)
+ checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2))
}
test("SPARK-4120 Join of multiple tables does not work in SparkSQL") {
@@ -975,40 +979,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
|FROM testData a,testData b,testData c
|where a.key = b.key and a.key = c.key
""".stripMargin),
- (1 to 100).map(i => Seq(i, i, i)))
+ (1 to 100).map(i => Row(i, i, i)))
}
test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"),
- (11 to 100).map(i => Seq(i)))
+ (11 to 100).map(i => Row(i)))
}
test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") {
checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
- (1 to 99).map(i => Seq(i)))
+ (1 to 99).map(i => Row(i)))
}
test("SPARK-4322 Grouping field with struct field as sub expression") {
jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
- checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
+ checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
dropTempTable("data")
jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
- checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
+ checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
dropTempTable("data")
}
test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
checkAnswer(
sql("SELECT a + b FROM testData2 ORDER BY a"),
- Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_))
+      Seq(2, 3, 3, 4, 4, 5).map(Row(_))
)
}
test("oder by asc by default when not specify ascending and descending") {
checkAnswer(
sql("SELECT a, b FROM testData2 ORDER BY a desc, b"),
- Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2))
+      Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))
)
}
@@ -1021,13 +1025,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
rdd2.registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
- (1 to 2).map(i => Seq(i)))
+ (1 to 2).map(i => Row(i)))
}
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.registerTempTable("distinctData")
- checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2)
+ checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index ee381da491..a015884bae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
rdd.registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
- Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
@@ -91,7 +91,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectNullData")
- assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null))
+ assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
test("query case class RDD with Nones") {
@@ -99,7 +99,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectOptionalData")
- assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null))
+ assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
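
Row.fromSeq, used in the two asserts above, wraps an existing Seq of values without varargs splatting. Roughly, assuming the new Row companion:

    import org.apache.spark.sql.Row

    assert(Row.fromSeq(Seq(1, 2)).length == 2)  // two fields
    assert(Row(Seq(1, 2)).length == 1)          // a single field holding a Seq
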
// Equality is broken for Arrays, so we test that separately.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 9be0b38e68..be2b34de07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -42,8 +42,8 @@ class ColumnStatsSuite extends FunSuite {
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
- assert(actual === expected)
+ columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+ case (actual, expected) => assert(actual === expected)
}
}
@@ -54,7 +54,7 @@ class ColumnStatsSuite extends FunSuite {
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
+ val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
val stats = columnStats.collectedStatistics
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index d94729ba92..e61f3c3963 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -49,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
- }.toSeq)
+ }.map(Row.fromTuple))
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
@@ -63,49 +63,49 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("SPARK-1678 regression: compression must not lose repeated values") {
checkAnswer(
sql("SELECT * FROM repeatedData"),
- repeatedData.collect().toSeq)
+ repeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
- repeatedData.collect().toSeq)
+ repeatedData.collect().toSeq.map(Row.fromTuple))
}
test("with null values") {
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
- nullableRepeatedData.collect().toSeq)
+ nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
- nullableRepeatedData.collect().toSeq)
+ nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-2729 regression: timestamp data type") {
checkAnswer(
sql("SELECT time FROM timestamps"),
- timestamps.collect().toSeq)
+ timestamps.collect().toSeq.map(Row.fromTuple))
cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
- timestamps.collect().toSeq)
+ timestamps.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
- withEmptyParts.collect().toSeq)
+ withEmptyParts.collect().toSeq.map(Row.fromTuple))
cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
- withEmptyParts.collect().toSeq)
+ withEmptyParts.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-4182 Caching complex types") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 592cafbbdc..c3a3f8ddc3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -108,7 +108,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val queryExecution = schemaRdd.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
- schemaRdd.collect().map(_.head).toArray
+ schemaRdd.collect().map(_(0)).toArray
}
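
Because Row no longer inherits from Seq, collection accessors such as .head are gone; these suites switch to positional apply instead. Schematically:

    import org.apache.spark.sql.Row

    val row = Row(42, "x")
    assert(row(0) == 42)  // row(0) replaces the former row.head
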
val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index d9e488e0ff..8b518f0941 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -34,7 +34,7 @@ class BooleanBitSetSuite extends FunSuite {
val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
- val values = rows.map(_.head)
+ val values = rows.map(_(0))
rows.foreach(builder.appendFrom(_, 0))
val buffer = builder.build()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index 2cab5e0c44..272c0d4cb2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -59,7 +59,7 @@ class TgfSuite extends QueryTest {
checkAnswer(
inputData.generate(ExampleTGF()),
Seq(
- "michael is 29 years old" :: Nil,
- "Next year, michael will be 30 years old" :: Nil))
+ Row("michael is 29 years old"),
+ Row("Next year, michael will be 30 years old")))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 2bc9aede32..94d14acccb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -229,13 +229,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (new java.math.BigDecimal("92233720368547758070"),
- true,
- 1.7976931348623157E308,
- 10,
- 21474836470L,
- null,
- "this is a simple string.") :: Nil
+ Row(new java.math.BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.")
)
}
@@ -271,48 +271,49 @@ class JsonSuite extends QueryTest {
// Access elements of a primitive array.
checkAnswer(
sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"),
- ("str1", "str2", null) :: Nil
+ Row("str1", "str2", null)
)
// Access an array of null values.
checkAnswer(
sql("select arrayOfNull from jsonTable"),
- Seq(Seq(null, null, null, null)) :: Nil
+ Row(Seq(null, null, null, null))
)
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"),
- (new java.math.BigDecimal("922337203685477580700"),
- new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
+ Row(new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null)
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"),
- (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+ Row(Seq("1", "2", "3"), Seq("str1", "str2"))
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"),
- (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+ Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1))
)
    // Access elements of an array inside a field of type ArrayType(ArrayType).
checkAnswer(
sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"),
- ("str2", 2.1) :: Nil
+ Row("str2", 2.1)
)
// Access elements of an array of structs.
checkAnswer(
sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " +
"from jsonTable"),
- (true :: "str1" :: null :: Nil,
- false :: null :: null :: Nil,
- null :: null :: null :: Nil,
- null) :: Nil
+ Row(
+ Row(true, "str1", null),
+ Row(false, null, null),
+ Row(null, null, null),
+ null)
)
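
After this change a struct-typed field surfaces as a nested Row rather than a Seq, hence the Row-inside-Row expected values above. A small sketch:

    import org.apache.spark.sql.Row

    val outer = Row(Row(true, "str1"), null)
    val inner = outer(0).asInstanceOf[Row]  // a struct field is itself a Row
    assert(inner.toSeq == Seq(true, "str1"))
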
// Access a struct and fields inside of it.
@@ -327,13 +328,13 @@ class JsonSuite extends QueryTest {
// Access an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"),
- (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+ Row(Seq(4, 5, 6), Seq("str1", "str2"))
)
// Access elements of an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"),
- (5, null) :: Nil
+ Row(5, null)
)
}
@@ -344,14 +345,14 @@ class JsonSuite extends QueryTest {
// Right now, "field1" and "field2" are treated as aliases. We should fix it.
checkAnswer(
sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
- (true, "str1") :: Nil
+ Row(true, "str1")
)
// Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
// Getting all values of a specific field from an array of structs.
checkAnswer(
sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
- (Seq(true, false), Seq("str1", null)) :: Nil
+ Row(Seq(true, false), Seq("str1", null))
)
}
@@ -372,57 +373,57 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- ("true", 11L, null, 1.1, "13.1", "str1") ::
- ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
- ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
- (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
+ Row("true", 11L, null, 1.1, "13.1", "str1") ::
+ Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
+ Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
+ Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
)
// Number and Boolean conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_bool - 10 from jsonTable where num_bool > 11"),
- 2
+ Row(2)
)
// Widening to LongType
checkAnswer(
sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"),
- Seq(21474836370L) :: Seq(21474836470L) :: Nil
+ Row(21474836370L) :: Row(21474836470L) :: Nil
)
checkAnswer(
sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"),
- Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil
+ Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
// Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
- Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
+ Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
)
// Widening to DoubleType
checkAnswer(
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
- Seq(101.2) :: Seq(21474836471.2) :: Nil
+ Row(101.2) :: Row(21474836471.2) :: Nil
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- 92233720368547758071.2
+ Row(92233720368547758071.2)
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
- new java.math.BigDecimal("92233720368547758061.2").doubleValue
+ Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
checkAnswer(
sql("select * from jsonTable where str_bool = 'str1'"),
- ("true", 11L, null, 1.1, "13.1", "str1") :: Nil
+ Row("true", 11L, null, 1.1, "13.1", "str1")
)
}
@@ -434,24 +435,24 @@ class JsonSuite extends QueryTest {
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where NOT num_bool"),
- false
+ Row(false)
)
checkAnswer(
sql("select str_bool from jsonTable where NOT str_bool"),
- false
+ Row(false)
)
// Right now, the analyzer does not know that num_bool should be treated as a boolean.
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where num_bool"),
- true
+ Row(true)
)
checkAnswer(
sql("select str_bool from jsonTable where str_bool"),
- false
+ Row(false)
)
// The plan of the following DSL is
@@ -464,7 +465,7 @@ class JsonSuite extends QueryTest {
jsonSchemaRDD.
where('num_str > BigDecimal("92233720368547758060")).
select('num_str + 1.2 as Symbol("num")),
- new java.math.BigDecimal("92233720368547758061.2")
+ Row(new java.math.BigDecimal("92233720368547758061.2"))
)
// The following test will fail. The type of num_str is StringType.
@@ -475,7 +476,7 @@ class JsonSuite extends QueryTest {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
- Seq(14.3) :: Seq(92233720368547758071.2) :: Nil
+ Row(14.3) :: Row(92233720368547758071.2) :: Nil
)
}
@@ -496,10 +497,10 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (Seq(), "11", "[1,2,3]", Seq(null), "[]") ::
- (null, """{"field":false}""", null, null, "{}") ::
- (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") ::
- (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil
+ Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
+ Row(null, """{"field":false}""", null, null, "{}") ::
+ Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
+ Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil
)
}
@@ -518,16 +519,16 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
- """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) ::
- Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
- Seq(null, null, Seq("1", "2", "3")) :: Nil
+ Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
+ """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) ::
+ Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
+ Row(null, null, Seq("1", "2", "3")) :: Nil
)
// Treat an element as a number.
checkAnswer(
sql("select array1[0] + 1 from jsonTable where array1 is not null"),
- 2
+ Row(2)
)
}
@@ -568,13 +569,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -594,13 +595,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTableSQL"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -626,13 +627,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable1"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema)
@@ -643,13 +644,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable2"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -659,7 +660,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
- (true, "str1") :: Nil
+ Row(true, "str1")
)
checkAnswer(
sql(
@@ -667,7 +668,7 @@ class JsonSuite extends QueryTest {
|select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1]
|from jsonTable
""".stripMargin),
- ("str2", 6) :: Nil
+ Row("str2", 6)
)
}
@@ -681,7 +682,7 @@ class JsonSuite extends QueryTest {
|select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0]
|from jsonTable
""".stripMargin),
- (5, 7, 8) :: Nil
+ Row(5, 7, 8)
)
checkAnswer(
sql(
@@ -690,7 +691,7 @@ class JsonSuite extends QueryTest {
|arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4
|from jsonTable
""".stripMargin),
- ("str1", Nil, "str4", 2) :: Nil
+ Row("str1", Nil, "str4", 2)
)
}
@@ -704,10 +705,10 @@ class JsonSuite extends QueryTest {
|select a, b, c
|from jsonTable
""".stripMargin),
- ("str_a_1", null, null) ::
- ("str_a_2", null, null) ::
- (null, "str_b_3", null) ::
- ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ Row("str_a_1", null, null) ::
+ Row("str_a_2", null, null) ::
+ Row(null, "str_b_3", null) ::
+ Row("str_a_4", "str_b_4", "str_c_4") :: Nil
)
}
@@ -734,12 +735,12 @@ class JsonSuite extends QueryTest {
|SELECT a, b, c, _unparsed
|FROM jsonTable
""".stripMargin),
- (null, null, null, "{") ::
- (null, null, null, "") ::
- (null, null, null, """{"a":1, b:2}""") ::
- (null, null, null, """{"a":{, b:3}""") ::
- ("str_a_4", "str_b_4", "str_c_4", null) ::
- (null, null, null, "]") :: Nil
+ Row(null, null, null, "{") ::
+ Row(null, null, null, "") ::
+ Row(null, null, null, """{"a":1, b:2}""") ::
+ Row(null, null, null, """{"a":{, b:3}""") ::
+ Row("str_a_4", "str_b_4", "str_c_4", null) ::
+ Row(null, null, null, "]") :: Nil
)
checkAnswer(
@@ -749,7 +750,7 @@ class JsonSuite extends QueryTest {
|FROM jsonTable
|WHERE _unparsed IS NULL
""".stripMargin),
- ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ Row("str_a_4", "str_b_4", "str_c_4")
)
checkAnswer(
@@ -759,11 +760,11 @@ class JsonSuite extends QueryTest {
|FROM jsonTable
|WHERE _unparsed IS NOT NULL
""".stripMargin),
- Seq("{") ::
- Seq("") ::
- Seq("""{"a":1, b:2}""") ::
- Seq("""{"a":{, b:3}""") ::
- Seq("]") :: Nil
+ Row("{") ::
+ Row("") ::
+ Row("""{"a":1, b:2}""") ::
+ Row("""{"a":{, b:3}""") ::
+ Row("]") :: Nil
)
TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
@@ -793,10 +794,10 @@ class JsonSuite extends QueryTest {
|SELECT field1, field2, field3, field4
|FROM jsonTable
""".stripMargin),
- Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) ::
- Seq(null, Seq(null, Seq(Seq(1))), null, null) ::
- Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) ::
- Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
+ Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) ::
+ Row(null, Seq(null, Seq(Row(1))), null, null) ::
+ Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) ::
+ Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
)
}
@@ -851,12 +852,12 @@ class JsonSuite extends QueryTest {
primTable.registerTempTable("primitiveTable")
checkAnswer(
sql("select * from primativeTable"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1)
@@ -865,38 +866,38 @@ class JsonSuite extends QueryTest {
// Access elements of a primitive array.
checkAnswer(
sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"),
- ("str1", "str2", null) :: Nil
+ Row("str1", "str2", null)
)
// Access an array of null values.
checkAnswer(
sql("select arrayOfNull from complexTable"),
- Seq(Seq(null, null, null, null)) :: Nil
+ Row(Seq(null, null, null, null))
)
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"),
- (new java.math.BigDecimal("922337203685477580700"),
- new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
+ Row(new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null)
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"),
- (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+ Row(Seq("1", "2", "3"), Seq("str1", "str2"))
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"),
- (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+ Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1))
)
// Access elements of an array inside a field of type ArrayType(ArrayType).
checkAnswer(
sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"),
- ("str2", 2.1) :: Nil
+ Row("str2", 2.1)
)
// Access a struct and fields inside of it.
@@ -911,13 +912,13 @@ class JsonSuite extends QueryTest {
// Access an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"),
- (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+ Row(Seq(4, 5, 6), Seq("str1", "str2"))
)
// Access elements of an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
- (5, null) :: Nil
+ Row(5, null)
)
}
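
The JsonSuite hunks above all follow one mechanical pattern: with Row no longer inheriting from Seq, expected answers become explicit Row values, and single-row expectations can drop the trailing `:: Nil` thanks to the single-Row checkAnswer overload this patch adds to QueryTest. A minimal sketch of the two shapes (the values are taken from the hunks above, the rest is illustrative):

    import org.apache.spark.sql.Row

    // Single-row expectation: a bare Row, no one-element list needed, e.g.
    //   checkAnswer(sql("select ... from jsonTable"), Row(true, "str1"))

    // Multi-row expectation: still a list, but of Rows instead of tuples.
    val expected: Seq[Row] =
      Row("str_a_1", null, null) ::
      Row("str_a_4", "str_b_4", "str_c_4") :: Nil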
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 4c3a04506c..4ad8c47200 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -46,7 +46,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
checker: (SchemaRDD, Any) => Unit,
- expectedResult: => Any): Unit = {
+ expectedResult: Any): Unit = {
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
val query = rdd.select(output.map(_.attr): _*).where(predicate)
@@ -65,11 +65,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
}
}
- private def checkFilterPushdown
+ private def checkFilterPushdown1
(rdd: SchemaRDD, output: Symbol*)
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate])
- (expectedResult: => Any): Unit = {
- checkFilterPushdown(rdd, output, predicate, filterClass, checkAnswer _, expectedResult)
+ (expectedResult: => Seq[Row]): Unit = {
+ checkFilterPushdown(rdd, output, predicate, filterClass,
+ (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), expectedResult)
+ }
+
+ private def checkFilterPushdown
+ (rdd: SchemaRDD, output: Symbol*)
+ (predicate: Predicate, filterClass: Class[_ <: FilterPredicate])
+ (expectedResult: Int): Unit = {
+ checkFilterPushdown(rdd, output, predicate, filterClass,
+ (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), Seq(Row(expectedResult)))
}
def checkBinaryFilterPushdown
@@ -89,27 +98,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - boolean") {
withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) {
Seq(Row(true), Row(false))
}
- checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true)
- checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) {
- false
- }
+ checkFilterPushdown1(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(Seq(Row(true)))
+ checkFilterPushdown1(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(Seq(Row(false)))
}
}
test("filter pushdown - integer") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) {
(1 to 4).map(Row.apply(_))
}
checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) {
+ checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) {
(2 to 4).map(Row.apply(_))
}
@@ -126,7 +133,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4)
checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
+ checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
Seq(Row(1), Row(4))
}
}
@@ -134,13 +141,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - long") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) {
(1 to 4).map(Row.apply(_))
}
checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) {
+ checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) {
(2 to 4).map(Row.apply(_))
}
@@ -157,7 +164,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4)
checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
+ checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
Seq(Row(1), Row(4))
}
}
@@ -165,13 +172,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - float") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) {
(1 to 4).map(Row.apply(_))
}
checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) {
+ checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) {
(2 to 4).map(Row.apply(_))
}
@@ -188,7 +195,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4)
checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
+ checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
Seq(Row(1), Row(4))
}
}
@@ -196,13 +203,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - double") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) {
(1 to 4).map(Row.apply(_))
}
checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) {
+ checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) {
(2 to 4).map(Row.apply(_))
}
@@ -219,7 +226,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4)
checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
+ checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
Seq(Row(1), Row(4))
}
}
@@ -227,30 +234,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - string") {
withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) {
+ checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row])
+ checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) {
(1 to 4).map(i => Row.apply(i.toString))
}
- checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) {
+ checkFilterPushdown1(rdd, '_1)('_1 === "1", classOf[Eq[String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) {
(2 to 4).map(i => Row.apply(i.toString))
}
- checkFilterPushdown(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])("4")
+ checkFilterPushdown1(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])(Seq(Row("4")))
+ checkFilterPushdown1(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])(Seq(Row("4")))
- checkFilterPushdown(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])("4")
+ checkFilterPushdown1(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])(Seq(Row("4")))
+ checkFilterPushdown1(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])(Seq(Row("1")))
+ checkFilterPushdown1(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])(Seq(Row("4")))
- checkFilterPushdown(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])("3")
- checkFilterPushdown(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) {
+ checkFilterPushdown1(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])(Seq(Row("4")))
+ checkFilterPushdown1(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])(Seq(Row("3")))
+ checkFilterPushdown1(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) {
Seq(Row("1"), Row("4"))
}
}
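
For reference, this is the calling shape the ParquetFilterSuite refactoring ends up with: the renamed checkFilterPushdown1 takes a by-name Seq[Row], while the surviving checkFilterPushdown overload accepts a bare Int and wraps it as Seq(Row(n)). A self-contained sketch with stand-in types; the real helpers also assert that the expected FilterPredicate class was actually pushed down:

    object FilterPushdownSketch {
      case class Row(values: Any*)  // stand-in for org.apache.spark.sql.Row

      // By-name Seq[Row] variant: `expected` is only evaluated when the check runs.
      def checkFilterPushdown1(label: String)(expected: => Seq[Row]): Unit =
        println(s"$label expects ${expected.size} row(s)")

      // Int convenience overload, mirroring the patch: wrap as Seq(Row(n)).
      def checkFilterPushdown(label: String)(expected: Int): Unit =
        checkFilterPushdown1(label)(Seq(Row(expected)))

      def main(args: Array[String]): Unit = {
        checkFilterPushdown1("'_1 < 2 || '_1 > 3")(Seq(Row(1), Row(4)))
        checkFilterPushdown("'_1 === 1")(1)
      }
    }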
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 973819aaa4..a57e4e85a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -68,8 +68,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
*/
- protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
- withParquetRDD(data)(checkAnswer(_, data))
+ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = {
+ withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
}
test("basic data types (without binary)") {
@@ -143,7 +143,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
withParquetRDD(data) { rdd =>
// Structs are converted to `Row`s
checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Tuple1(Row(struct.productIterator.toSeq: _*))
+ Row(Row(struct.productIterator.toSeq: _*))
})
}
}
@@ -153,7 +153,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
withParquetRDD(data) { rdd =>
// Structs are converted to `Row`s
checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Tuple1(Row(struct.productIterator.toSeq: _*))
+ Row(Row(struct.productIterator.toSeq: _*))
})
}
}
@@ -162,7 +162,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i"))))
withParquetRDD(data) { rdd =>
checkAnswer(rdd, data.map { case Tuple1(m) =>
- Tuple1(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
+ Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
})
}
}
@@ -261,7 +261,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
checkAnswer(parquetFile(path.toString), (0 until 10).map { i =>
- (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
})
}
}
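
The ParquetIOSuite changes lean on Row.fromTuple (used verbatim in the patch) to lift tuple test data into Rows, and nest Row inside Row where structs used to be Tuple1-wrapped. A sketch, assuming only the Row companion helpers that appear elsewhere in this diff; the case class S is a made-up example struct:

    import org.apache.spark.sql.Row

    val data = (1 to 4).map(i => (i, s"val_$i"))
    val expected: Seq[Row] = data.map(Row.fromTuple)  // top-level tuples become Rows

    case class S(a: Int, b: String)
    val struct = S(1, "one")
    val nested: Row = Row(Row(struct.productIterator.toSeq: _*))  // struct -> inner Row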
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 3a073a6b70..2c5345b1f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -28,7 +28,7 @@ import parquet.hadoop.util.ContextUtil
import parquet.io.api.Binary
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Row => _, _}
import org.apache.spark.sql.catalyst.util.getTempFilePath
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@@ -191,8 +191,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
parquetFile(path).registerTempTable("tmp")
checkAnswer(
sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ Row(5, "val_5") ::
+ Row(7, "val_7") :: Nil)
Utils.deleteRecursively(file)
@@ -207,8 +207,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
parquetFile(path).registerTempTable("tmp")
checkAnswer(
sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ Row(5, "val_5") ::
+ Row(7, "val_7") :: Nil)
Utils.deleteRecursively(file)
@@ -223,8 +223,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
parquetFile(path).registerTempTable("tmp")
checkAnswer(
sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ Row(5, "val_5") ::
+ Row(7, "val_7") :: Nil)
Utils.deleteRecursively(file)
@@ -239,8 +239,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
parquetFile(path).registerTempTable("tmp")
checkAnswer(
sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ Row(5, "val_5") ::
+ Row(7, "val_7") :: Nil)
Utils.deleteRecursively(file)
@@ -255,8 +255,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
parquetFile(path).registerTempTable("tmp")
checkAnswer(
sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ Row(5, "val_5") ::
+ Row(7, "val_7") :: Nil)
Utils.deleteRecursively(file)
@@ -303,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result.size === 9, "self-join result has incorrect size")
assert(result(0).size === 12, "result row has incorrect size")
result.zipWithIndex.foreach {
- case (row, index) => row.zipWithIndex.foreach {
+ case (row, index) => row.toSeq.zipWithIndex.foreach {
case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column")
}
}
@@ -423,7 +423,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
val readFile = parquetFile(path)
val rdd_saved = readFile.collect()
- assert(rdd_saved(0) === Seq.fill(5)(null))
+ assert(rdd_saved(0) === Row(null, null, null, null, null))
Utils.deleteRecursively(file)
assert(true)
}
@@ -438,7 +438,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
val readFile = parquetFile(path)
val rdd_saved = readFile.collect()
- assert(rdd_saved(0) === Seq.fill(5)(null))
+ assert(rdd_saved(0) === Row(null, null, null, null, null))
Utils.deleteRecursively(file)
assert(true)
}
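
Two idioms recur in the ParquetQuerySuite changes: element-wise access now detours through toSeq, and an all-null result compares against a constructed Row rather than Seq.fill. A sketch of both, mirroring the assertions above:

    import org.apache.spark.sql.Row

    val row = Row(null, null, null, null, null)
    assert(row == Row(null, null, null, null, null))  // Row-to-Row equality, not Seq

    row.toSeq.zipWithIndex.foreach { case (field, column) =>
      assert(field == null, s"unexpected value in column $column")
    }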
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala
index 4c081fb451..7b3f8c22af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala
@@ -38,7 +38,7 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest {
val data = (0 until 10).map(i => (i, i.toString))
withParquetTable(data, "t") {
sql("INSERT INTO t SELECT * FROM t")
- checkAnswer(table("t"), data ++ data)
+ checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 264f6d94c4..b1e0919b7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -244,7 +244,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT count(*) FROM tableWithSchema",
- 10)
+ Seq(Row(10)))
sqlTest(
"SELECT `string$%Field` FROM tableWithSchema",
@@ -260,7 +260,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1",
- Seq(Seq(1, 2)))
+ Seq(Row(1, 2)))
sqlTest(
"SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 10833c1132..3e26fe3675 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -368,10 +368,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
.mkString("\t")
}
case command: ExecutedCommand =>
- command.executeCollect().map(_.head.toString)
+ command.executeCollect().map(_(0).toString)
case other =>
- val result: Seq[Seq[Any]] = other.executeCollect().toSeq
+ val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
// Reformat to match hive tab delimited output.
@@ -395,7 +395,7 @@ private object HiveContext {
protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
- struct.zip(fields).map {
+ struct.toSeq.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
case (seq: Seq[_], ArrayType(typ, _)) =>
@@ -418,7 +418,7 @@ private object HiveContext {
/** Hive outputs fields of structs slightly differently than top level attributes. */
protected def toHiveStructString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
- struct.zip(fields).map {
+ struct.toSeq.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
case (seq: Seq[_], ArrayType(typ, _)) =>
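
The struct-formatting change in HiveContext is the same toSeq detour applied to rendering: pair each field value with its StructField before building the Hive-style string. A minimal sketch with stand-in types, not the real Catalyst classes:

    object HiveStringSketch {
      case class StructField(name: String)               // stand-in
      case class Row(values: Any*) { def toSeq: Seq[Any] = values }

      def toHiveStructString(struct: Row, fields: Seq[StructField]): String =
        struct.toSeq.zip(fields).map {
          case (v, t) => s""""${t.name}":$v"""
        }.mkString("{", ",", "}")

      def main(args: Array[String]): Unit =
        // prints {"x":1,"y":a}
        println(toHiveStructString(Row(1, "a"), Seq(StructField("x"), StructField("y"))))
    }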
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index eeabfdd857..82dba99900 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -348,7 +348,7 @@ private[hive] trait HiveInspectors {
(o: Any) => {
if (o != null) {
val struct = soi.create()
- (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+ (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach {
(field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
@@ -432,7 +432,7 @@ private[hive] trait HiveInspectors {
}
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Seq[_]]
+ val row = a.asInstanceOf[Row]
// 1. create the pojo (most likely) object
val result = x.create()
var i = 0
@@ -448,7 +448,7 @@ private[hive] trait HiveInspectors {
result
case x: StructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Seq[_]]
+ val row = a.asInstanceOf[Row]
val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
var i = 0
while (i < fieldRefs.length) {
@@ -475,7 +475,7 @@ private[hive] trait HiveInspectors {
}
def wrap(
- row: Seq[Any],
+ row: Row,
inspectors: Seq[ObjectInspector],
cache: Array[AnyRef]): Array[AnyRef] = {
var i = 0
@@ -486,6 +486,18 @@ private[hive] trait HiveInspectors {
cache
}
+ def wrap(
+ row: Seq[Any],
+ inspectors: Seq[ObjectInspector],
+ cache: Array[AnyRef]): Array[AnyRef] = {
+ var i = 0
+ while (i < inspectors.length) {
+ cache(i) = wrap(row(i), inspectors(i))
+ i += 1
+ }
+ cache
+ }
+
/**
* @param dataType Catalyst data type
* @return Hive java object inspector (recursively), not the Writable ObjectInspector
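
The wrap duplication above is the compatibility shim in miniature: with Row no longer a Seq[Any], the trait needs one entry point per static type even though the loop bodies are identical. A stand-in sketch, where Inspector is a placeholder for Hive's ObjectInspector-based wrapping:

    object WrapSketch {
      case class Row(values: Any*) { def apply(i: Int): Any = values(i) }
      type Inspector = Any => AnyRef  // placeholder for the real wrap(Any, ObjectInspector)

      // Row entry point, as in the patched trait.
      def wrap(row: Row, inspectors: Seq[Inspector], cache: Array[AnyRef]): Array[AnyRef] = {
        var i = 0
        while (i < inspectors.length) { cache(i) = inspectors(i)(row(i)); i += 1 }
        cache
      }

      // Seq[Any] entry point kept alongside it for callers that still hold a Seq.
      def wrap(row: Seq[Any], inspectors: Seq[Inspector], cache: Array[AnyRef]): Array[AnyRef] = {
        var i = 0
        while (i < inspectors.length) { cache(i) = inspectors(i)(row(i)); i += 1 }
        cache
      }

      def main(args: Array[String]): Unit = {
        val insp: Seq[Inspector] = Seq(v => v.toString, v => v.toString)
        println(wrap(Row(1, "a"), insp, new Array[AnyRef](2)).mkString(", "))
      }
    }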
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index d898b876c3..76d2140372 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction(
protected lazy val cached = new Array[AnyRef](exprs.length)
def update(input: Row): Unit = {
- val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
+ val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached))
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index cc8bb3e172..aae175e426 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -209,7 +209,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = {
val dynamicPartPath = dynamicPartColNames
- .zip(row.takeRight(dynamicPartColNames.length))
+ .zip(row.toSeq.takeRight(dynamicPartColNames.length))
.map { case (col, rawVal) =>
val string = if (rawVal == null) null else String.valueOf(rawVal)
s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f89c49d292..f320d732fb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._
* So, we duplicate this code here.
*/
class QueryTest extends PlanTest {
+
/**
* Runs the plan and makes sure the answer contains all of the keywords, or that
* none of the keywords are listed in the answer.
@@ -56,17 +57,20 @@ class QueryTest extends PlanTest {
* @param rdd the [[SchemaRDD]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result as a [[Seq]] of [[Row]] objects.
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
- val convertedAnswer = expectedAnswer match {
- case s: Seq[_] if s.isEmpty => s
- case s: Seq[_] if s.head.isInstanceOf[Product] &&
- !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
- case s: Seq[_] => s
- case singleItem => Seq(Seq(singleItem))
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case o => o
+ })
+ }
+ if (!isSorted) converted.sortBy(_.toString) else converted
}
-
- val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
- def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
fail(
@@ -74,11 +78,12 @@ class QueryTest extends PlanTest {
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
- |${stackTraceToString(e)}
+ |$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
- if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -88,11 +93,22 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
+
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
}
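
The heart of the Hive-side QueryTest rewrite is prepareAnswer: normalise java.math.BigDecimal to scala.BigDecimal, whose equality behaves like compareTo, and impose a deterministic order unless the query itself was sorted. A self-contained sketch with a stand-in Row:

    object PrepareAnswerSketch {
      case class Row(values: Any*) {
        def toSeq: Seq[Any] = values
        override def toString: String = values.mkString("[", ",", "]")
      }
      object Row { def fromSeq(s: Seq[Any]): Row = Row(s: _*) }

      def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
        val converted = answer.map { r =>
          Row.fromSeq(r.toSeq.map {
            case d: java.math.BigDecimal => BigDecimal(d)  // comparable equality
            case o => o
          })
        }
        if (!isSorted) converted.sortBy(_.toString) else converted
      }

      def main(args: Array[String]): Unit =
        println(prepareAnswer(Seq(Row(2, "b"), Row(1, "a")), isSorted = false))
    }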
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 4864607252..2d3ff68012 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -129,6 +129,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
}
}
+ def checkValues(row1: Seq[Any], row2: Row): Unit = {
+ row1.zip(row2.toSeq).map {
+ case (r1, r2) => checkValue(r1, r2)
+ }
+ }
+
def checkValue(v1: Any, v2: Any): Unit = {
(v1, v2) match {
case (r1: Decimal, r2: Decimal) =>
@@ -198,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
case (t, idx) => StructField(s"c_$idx", t)
})
- checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
+ checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
}
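
The new checkValues helper is just a zipped traversal of a plain Seq against a Row's fields. In sketch form, with a trivial equality assert standing in for the suite's type-aware checkValue:

    import org.apache.spark.sql.Row

    def checkValue(v1: Any, v2: Any): Unit = assert(v1 == v2)  // trivial stand-in

    def checkValues(row1: Seq[Any], row2: Row): Unit =
      row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) }

    checkValues(Seq(1, "a"), Row(1, "a"))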
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 7cfb875e05..0e6636d38e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
+ testData.collect().toSeq.map(Row.fromTuple)
)
// Add more data.
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq ++ testData.collect().toSeq
+ testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
)
// Now overwrite.
@@ -61,7 +61,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the registered table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
+ testData.collect().toSeq.map(Row.fromTuple)
)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 53d8aa7739..7408c7ffd6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b") :: Nil)
+ Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -164,14 +164,14 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
// will show.
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a1", "b1") :: Nil)
+ Row("a1", "b1"))
refreshTable("jsonTable")
// Check that the refresh worked
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a1", "b1", "c1") :: Nil)
+ Row("a1", "b1", "c1"))
FileUtils.deleteDirectory(tempDir)
}
@@ -191,7 +191,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b") :: Nil)
+ Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -210,7 +210,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
// New table should reflect new schema.
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b", "c") :: Nil)
+ Row("a", "b", "c"))
FileUtils.deleteDirectory(tempDir)
}
@@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|)
""".stripMargin)
- sql("DROP TABLE jsonTable").collect.foreach(println)
+ sql("DROP TABLE jsonTable").collect().foreach(println)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 0b4e76c9d3..6f07fd5a87 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
-import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
before: () => Unit,
after: () => Unit,
query: String,
- expectedAnswer: Seq[Any],
+ expectedAnswer: Seq[Row],
ct: ClassTag[_]) = {
before()
@@ -183,7 +183,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
/** Tests for MetastoreRelation */
val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key"""
- val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238"))
+ val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238"))
mkTest(
() => (),
() => (),
@@ -197,7 +197,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
val leftSemiJoinQuery =
"""SELECT * FROM src a
|left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
- val answer = (86, "val_86") :: Nil
+ val answer = Row(86, "val_86")
var rdd = sql(leftSemiJoinQuery)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index c14f0d24e0..df72be7746 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -226,7 +226,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Jdk version leads to different query output for double, so not use createQueryTest here
test("division") {
val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head
- Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x =>
+ Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x =>
assert(x._1 == x._2.asInstanceOf[Double]))
}
@@ -235,7 +235,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("Query expressed in SQL") {
setConf("spark.sql.dialect", "sql")
- assert(sql("SELECT 1").collect() === Array(Seq(1)))
+ assert(sql("SELECT 1").collect() === Array(Row(1)))
setConf("spark.sql.dialect", "hiveql")
}
@@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestData(2, "str2") :: Nil)
testData.registerTempTable("REGisteredTABle")
- assertResult(Array(Array(2, "str2"))) {
+ assertResult(Array(Row(2, "str2"))) {
sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
"WHERE TableAliaS.a > 1").collect()
}
@@ -553,12 +553,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Describe a table
assertResult(
Array(
- Array("key", "int", null),
- Array("value", "string", null),
- Array("dt", "string", null),
- Array("# Partition Information", "", ""),
- Array("# col_name", "data_type", "comment"),
- Array("dt", "string", null))
+ Row("key", "int", null),
+ Row("value", "string", null),
+ Row("dt", "string", null),
+ Row("# Partition Information", "", ""),
+ Row("# col_name", "data_type", "comment"),
+ Row("dt", "string", null))
) {
sql("DESCRIBE test_describe_commands1")
.select('col_name, 'data_type, 'comment)
@@ -568,12 +568,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Describe a table with a fully qualified table name
assertResult(
Array(
- Array("key", "int", null),
- Array("value", "string", null),
- Array("dt", "string", null),
- Array("# Partition Information", "", ""),
- Array("# col_name", "data_type", "comment"),
- Array("dt", "string", null))
+ Row("key", "int", null),
+ Row("value", "string", null),
+ Row("dt", "string", null),
+ Row("# Partition Information", "", ""),
+ Row("# col_name", "data_type", "comment"),
+ Row("dt", "string", null))
) {
sql("DESCRIBE default.test_describe_commands1")
.select('col_name, 'data_type, 'comment)
@@ -623,8 +623,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
assertResult(
Array(
- Array("a", "IntegerType", null),
- Array("b", "StringType", null))
+ Row("a", "IntegerType", null),
+ Row("b", "StringType", null))
) {
sql("DESCRIBE test_describe_commands2")
.select('col_name, 'data_type, 'comment)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 5dafcd6c0a..f2374a2152 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest {
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
checkAnswer(
sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
- 8
+ Row(8)
)
}
@@ -115,7 +115,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
checkAnswer(
sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(),
- Seq(Seq("1"), Seq("2")))
+ Seq(Row("1"), Row("2")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
TestHive.reset()
@@ -131,7 +131,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
checkAnswer(
sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(),
- Seq(Seq(0), Seq(2), Seq(13)))
+ Seq(Row(0), Row(2), Row(13)))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
TestHive.reset()
@@ -146,7 +146,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
checkAnswer(
sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(),
- Seq(Seq("a,b,c"), Seq("d,e")))
+ Seq(Row("a,b,c"), Row("d,e")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
TestHive.reset()
@@ -160,7 +160,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
checkAnswer(
sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(),
- Seq(Seq("hello world"), Seq("hello goodbye")))
+ Seq(Row("hello world"), Row("hello goodbye")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")
TestHive.reset()
@@ -177,7 +177,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
checkAnswer(
sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(),
- Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
+ Seq(Row("0, 0"), Row("2, 2"), Row("13, 13")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
TestHive.reset()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index d41eb9e870..f6bf2dbb5d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest {
}
test("CTAS with serde") {
- sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect
+ sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect()
sql(
"""CREATE TABLE ctas2
| ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
@@ -51,23 +51,23 @@ class SQLQuerySuite extends QueryTest {
| AS
| SELECT key, value
| FROM src
- | ORDER BY key, value""".stripMargin).collect
+ | ORDER BY key, value""".stripMargin).collect()
sql(
"""CREATE TABLE ctas3
| ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012'
| STORED AS textfile AS
| SELECT key, value
| FROM src
- | ORDER BY key, value""".stripMargin).collect
+ | ORDER BY key, value""".stripMargin).collect()
// the table schema may look like (key: integer, value: string)
sql(
"""CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect
+ | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect()
// does nothing because the table ctas4 already exists.
sql(
"""CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()
checkAnswer(
sql("SELECT k, value FROM ctas1 ORDER BY k, value"),
@@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest {
intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] {
sql(
"""CREATE TABLE ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()
}
checkAnswer(
sql("SELECT key, value FROM ctas4 ORDER BY key, value"),
@@ -126,7 +126,7 @@ class SQLQuerySuite extends QueryTest {
sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
- 1)
+ Row(1))
checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"),
Seq.empty[Row])
checkAnswer(
@@ -233,7 +233,7 @@ class SQLQuerySuite extends QueryTest {
| (s struct<innerStruct: struct<s1:string>,
| innerArray:array<int>,
| innerMap: map<string, int>>)
- """.stripMargin).collect
+ """.stripMargin).collect()
sql(
"""
@@ -243,7 +243,7 @@ class SQLQuerySuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM nullValuesInInnerComplexTypes"),
- Seq(Seq(Seq(null, null, null)))
+ Row(Row(null, null, null))
)
sql("DROP TABLE nullValuesInInnerComplexTypes")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
index 4bc14bad0a..581f666399 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
@@ -39,7 +39,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("SELECT on Parquet table") {
val data = (1 to 4).map(i => (i, s"val_$i"))
withParquetTable(data, "t") {
- checkAnswer(sql("SELECT * FROM t"), data)
+ checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index 8bbb7f2fdb..79fd99d9f8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -177,81 +177,81 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
test(s"ordering of the partitioning columns $table") {
checkAnswer(
sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
- Seq.fill(10)((1, "part-1"))
+ Seq.fill(10)(Row(1, "part-1"))
)
checkAnswer(
sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
- Seq.fill(10)(("part-1", 1))
+ Seq.fill(10)(Row("part-1", 1))
)
}
test(s"project the partitioning column $table") {
checkAnswer(
sql(s"SELECT p, count(*) FROM $table group by p"),
- (1, 10) ::
- (2, 10) ::
- (3, 10) ::
- (4, 10) ::
- (5, 10) ::
- (6, 10) ::
- (7, 10) ::
- (8, 10) ::
- (9, 10) ::
- (10, 10) :: Nil
+ Row(1, 10) ::
+ Row(2, 10) ::
+ Row(3, 10) ::
+ Row(4, 10) ::
+ Row(5, 10) ::
+ Row(6, 10) ::
+ Row(7, 10) ::
+ Row(8, 10) ::
+ Row(9, 10) ::
+ Row(10, 10) :: Nil
)
}
test(s"project partitioning and non-partitioning columns $table") {
checkAnswer(
sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
- ("part-1", 1, 10) ::
- ("part-2", 2, 10) ::
- ("part-3", 3, 10) ::
- ("part-4", 4, 10) ::
- ("part-5", 5, 10) ::
- ("part-6", 6, 10) ::
- ("part-7", 7, 10) ::
- ("part-8", 8, 10) ::
- ("part-9", 9, 10) ::
- ("part-10", 10, 10) :: Nil
+ Row("part-1", 1, 10) ::
+ Row("part-2", 2, 10) ::
+ Row("part-3", 3, 10) ::
+ Row("part-4", 4, 10) ::
+ Row("part-5", 5, 10) ::
+ Row("part-6", 6, 10) ::
+ Row("part-7", 7, 10) ::
+ Row("part-8", 8, 10) ::
+ Row("part-9", 9, 10) ::
+ Row("part-10", 10, 10) :: Nil
)
}
test(s"simple count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table"),
- 100)
+ Row(100))
}
test(s"pruned count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
- 10)
+ Row(10))
}
test(s"non-existant partition $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
- 0)
+ Row(0))
}
test(s"multi-partition pruned count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
- 30)
+ Row(30))
}
test(s"non-partition predicates $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
- 30)
+ Row(30))
}
test(s"sum $table") {
checkAnswer(
sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
- 1 + 2 + 3)
+ Row(1 + 2 + 3))
}
test(s"hive udfs $table") {
@@ -266,6 +266,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
test("non-part select(*)") {
checkAnswer(
sql("SELECT COUNT(*) FROM normal_parquet"),
- 10)
+ Row(10))
}
}
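
Finally, the partitioning tests show the scalar-answer pattern one more time: aggregates that used to be compared against a bare Int are wrapped in a single Row. In sketch form:

    import org.apache.spark.sql.Row

    val expectedCount = Row(100)        // SELECT COUNT(*) FROM $table
    val expectedSum   = Row(1 + 2 + 3)  // SELECT SUM(intField) ... AND p = 1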