author     Yin Huai <yhuai@databricks.com>    2015-11-15 13:59:59 -0800
committer  Davies Liu <davies.liu@gmail.com>  2015-11-15 13:59:59 -0800
commit     3e2e1873b2762d07e49de8f9ea709bf3fa2d171c
tree       7b4656a38863a4757dbcfc6d29405d8ead60d800
parent     64e55511033afb6ef42be142eb371bfbc31f5230
[SPARK-11738] [SQL] Making ArrayType orderable
https://issues.apache.org/jira/browse/SPARK-11738

Author: Yin Huai <yhuai@databricks.com>

Closes #9718 from yhuai/makingArrayOrderable.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala | 48
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 32
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala | 36
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala | 124
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 52
14 files changed, 335 insertions, 94 deletions
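
For context on what this change enables: with ArrayType orderable, array-typed columns can be used in GROUP BY, ORDER BY, and DISTINCT. A minimal usage sketch, not part of this patch, assuming a Spark 1.6-era SQLContext named `sqlContext` with its implicits imported:

    import sqlContext.implicits._

    // Hypothetical example data: an array-typed grouping key and an integer value.
    val df = Seq(
      (Seq(1, 2), 10),
      (Seq(1, 2), 20),
      (Seq(2), 30)
    ).toDF("key", "value")

    // Before this patch, CheckAnalysis rejected array-typed grouping and sort expressions.
    df.groupBy("key").sum("value").orderBy("key").show()
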
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 5a4b0c1e39..7b2c93d63d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -137,32 +137,14 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- def checkSupportedGroupingDataType(
- expressionString: String,
- dataType: DataType): Unit = dataType match {
- case BinaryType =>
- failAnalysis(s"expression $expressionString cannot be used in " +
- s"grouping expression because it is in binary type or its inner field is " +
- s"in binary type")
- case a: ArrayType =>
- failAnalysis(s"expression $expressionString cannot be used in " +
- s"grouping expression because it is in array type or its inner field is " +
- s"in array type")
- case m: MapType =>
- failAnalysis(s"expression $expressionString cannot be used in " +
- s"grouping expression because it is in map type or its inner field is " +
- s"in map type")
- case s: StructType =>
- s.fields.foreach { f =>
- checkSupportedGroupingDataType(expressionString, f.dataType)
- }
- case udt: UserDefinedType[_] =>
- checkSupportedGroupingDataType(expressionString, udt.sqlType)
- case _ => // OK
- }
-
def checkValidGroupingExprs(expr: Expression): Unit = {
- checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+ // Check if the data type of expr is orderable.
+ if (!RowOrdering.isOrderable(expr.dataType)) {
+ failAnalysis(
+ s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+ s"because its data type ${expr.dataType.simpleString} is not a orderable " +
+ s"data type.")
+ }
if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index ccd91d3549..1718cfbd35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -267,6 +267,49 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
+ case array: ArrayType =>
+ val elementType = array.elementType
+ val elementA = freshName("elementA")
+ val isNullA = freshName("isNullA")
+ val elementB = freshName("elementB")
+ val isNullB = freshName("isNullB")
+ val compareFunc = freshName("compareArray")
+ val minLength = freshName("minLength")
+ val funcCode: String =
+ s"""
+ public int $compareFunc(ArrayData a, ArrayData b) {
+ int lengthA = a.numElements();
+ int lengthB = b.numElements();
+ int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
+ for (int i = 0; i < $minLength; i++) {
+ boolean $isNullA = a.isNullAt(i);
+ boolean $isNullB = b.isNullAt(i);
+ if ($isNullA && $isNullB) {
+ // Both elements are null: treat them as equal and keep scanning.
+ } else if ($isNullA) {
+ return -1;
+ } else if ($isNullB) {
+ return 1;
+ } else {
+ ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
+ ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
+ int comp = ${genComp(elementType, elementA, elementB)};
+ if (comp != 0) {
+ return comp;
+ }
+ }
+ }
+
+ if (lengthA < lengthB) {
+ return -1;
+ } else if (lengthA > lengthB) {
+ return 1;
+ }
+ return 0;
+ }
+ """
+ addNewFunction(compareFunc, funcCode)
+ s"this.$compareFunc($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
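
The generated comparator above implements a null-first, element-wise lexicographic ordering, breaking ties by array length. A minimal Scala sketch of the same comparison semantics (illustrative only; compareSeqs is a hypothetical helper, not the generated code):

    // Nulls compare as equal to each other and sort before non-null elements;
    // equal prefixes are ordered by array length.
    def compareSeqs(a: Seq[Option[Int]], b: Seq[Option[Int]]): Int = {
      val minLength = math.min(a.length, b.length)
      var i = 0
      while (i < minLength) {
        (a(i), b(i)) match {
          case (None, None) => // both null: keep scanning
          case (None, _)    => return -1
          case (_, None)    => return 1
          case (Some(x), Some(y)) =>
            val c = Integer.compare(x, y)
            if (c != 0) return c
        }
        i += 1
      }
      Integer.compare(a.length, b.length)
    }

    assert(compareSeqs(Seq(Some(1), None), Seq(Some(1), Some(1))) < 0)             // null sorts first
    assert(compareSeqs(Seq(Some(1), Some(2)), Seq(Some(1), Some(2), Some(3))) < 0) // shorter prefix first
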
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2cf19b939f..741ad1f3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+ case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
@@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+ case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
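
With the new ArrayType case, sort_array can also sort arrays whose elements are themselves arrays, using this interpreted ordering. A hedged usage sketch that mirrors the updated DataFrameFunctionsSuite test later in this patch (assumes a SQLContext with implicits in scope):

    import sqlContext.implicits._

    val df = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b")
    // Ascending: null first, then lexicographically by element -> [null, [1], [2], [2, 4]]
    df.selectExpr("sort_array(a, true)").show(false)
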
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index 6407c73bc9..6112259fed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow]
dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
case dt: AtomicType if order.direction == Descending =>
dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+ case a: ArrayType if order.direction == Ascending =>
+ a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+ case a: ArrayType if order.direction == Descending =>
+ a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case s: StructType if order.direction == Ascending =>
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Descending =>
@@ -86,6 +90,8 @@ object RowOrdering {
case NullType => true
case dt: AtomicType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+ case array: ArrayType => isOrderable(array.elementType)
+ case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
case _ => false
}
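
After this change, RowOrdering.isOrderable accepts arrays whose element type is orderable and UDTs whose underlying SQL type is orderable, while maps remain non-orderable. A small illustrative sketch of the expected classification:

    import org.apache.spark.sql.catalyst.expressions.RowOrdering
    import org.apache.spark.sql.types._

    assert(RowOrdering.isOrderable(ArrayType(IntegerType)))                      // element type is orderable
    assert(RowOrdering.isOrderable(StructType(StructField("f", ArrayType(BooleanType)) :: Nil)))
    assert(!RowOrdering.isOrderable(MapType(StringType, LongType)))              // maps stay non-orderable
    assert(!RowOrdering.isOrderable(ArrayType(MapType(StringType, LongType))))   // non-orderable element type
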
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index bcf4d78fb9..f603cbfb0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -57,6 +57,7 @@ object TypeUtils {
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+ case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d2d007c2b..a5ae8bb0e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -84,6 +84,7 @@ private[sql] object TypeCollection {
* Types that can be ordered/compared. In the long run we should probably make this a trait
* that can be mixed into each data type, and perhaps create an [[AbstractDataType]].
*/
+ // TODO: Should we consolidate this with RowOrdering.isOrderable?
val Ordered = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 5770f59b53..a001eadcc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.types
+import org.apache.spark.sql.catalyst.util.ArrayData
import org.json4s.JsonDSL._
import org.apache.spark.annotation.DeveloperApi
+import scala.math.Ordering
+
object ArrayType extends AbstractDataType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
@@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType
override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
f(this) || elementType.existsRecursively(f)
}
+
+ @transient
+ private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
+ private[this] val elementOrdering: Ordering[Any] = elementType match {
+ case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
+ case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
+ case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+ case other =>
+ throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+ }
+
+ def compare(x: ArrayData, y: ArrayData): Int = {
+ val leftArray = x
+ val rightArray = y
+ val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
+ var i = 0
+ while (i < minLength) {
+ val isNullLeft = leftArray.isNullAt(i)
+ val isNullRight = rightArray.isNullAt(i)
+ if (isNullLeft && isNullRight) {
+ // Do nothing.
+ } else if (isNullLeft) {
+ return -1
+ } else if (isNullRight) {
+ return 1
+ } else {
+ val comp =
+ elementOrdering.compare(
+ leftArray.get(i, elementType),
+ rightArray.get(i, elementType))
+ if (comp != 0) {
+ return comp
+ }
+ }
+ i += 1
+ }
+ if (leftArray.numElements() < rightArray.numElements()) {
+ return -1
+ } else if (leftArray.numElements() > rightArray.numElements()) {
+ return 1
+ } else {
+ return 0
+ }
+ }
+ }
}
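
The new interpretedOrdering can be exercised directly on catalyst ArrayData values. A minimal sketch (illustrative only; since interpretedOrdering is private[sql], this assumes code compiled inside the org.apache.spark.sql package tree):

    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val ordering = ArrayType(IntegerType).interpretedOrdering
    val shorter  = new GenericArrayData(Array[Any](1, 2))
    val longer   = new GenericArrayData(Array[Any](1, 2, 3))
    val withNull = new GenericArrayData(Array[Any](1, null))

    assert(ordering.compare(shorter, longer) < 0)    // equal prefix, shorter array first
    assert(ordering.compare(withNull, shorter) < 0)  // a null element sorts before a non-null one
    assert(ordering.compare(longer, longer) == 0)
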
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2e7c3bd67b..ee43557874 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
+import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
import scala.beans.{BeanProperty, BeanInfo}
@@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
}
@BeanInfo
-private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])
private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
- override def sqlType: DataType = ArrayType(IntegerType)
+ override def sqlType: DataType = MapType(IntegerType, IntegerType)
- override def serialize(obj: Any): ArrayData = {
+ override def serialize(obj: Any): MapData = {
obj match {
- case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+ case groupableData: UngroupableData =>
+ val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
+ val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
+ new ArrayBasedMapData(keyArray, valueArray)
}
}
override def deserialize(datum: Any): UngroupableData = {
datum match {
- case data: Array[Int] => UngroupableData(data)
+ case data: MapData =>
+ val keyArray = data.keyArray().array
+ val valueArray = data.valueArray().array
+ assert(keyArray.length == valueArray.length)
+ val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
+ UngroupableData(mapData)
}
}
@@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"sorting by unsupported column types",
- listRelation.orderBy('list.asc),
- "sort" :: "type" :: "array<int>" :: Nil)
+ mapRelation.orderBy('map.asc),
+ "sort" :: "type" :: "map<int,int>" :: Nil)
errorTest(
"non-boolean filters",
@@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest {
case true =>
assertAnalysisSuccess(plan, true)
case false =>
- assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
+ assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
}
-
}
val supportedDataTypes = Seq(
- StringType,
+ StringType, BinaryType,
NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
+ ArrayType(IntegerType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", StringType, nullable = true),
+ new StructType()
+ .add("f1", FloatType, nullable = true)
+ .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
new GroupableUDT())
supportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = true)
}
val unsupportedDataTypes = Seq(
- BinaryType,
- ArrayType(IntegerType),
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
- .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
+ .add("f2", MapType(StringType, LongType), nullable = true),
new UngroupableUDT())
unsupportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index b902982add..ba1866efc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{TypeCollection, StringType}
+import org.apache.spark.sql.types.{LongType, TypeCollection, StringType}
class ExpressionTypeCheckingSuite extends SparkFunSuite {
@@ -32,7 +32,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
'intField.int,
'stringField.string,
'booleanField.boolean,
- 'complexField.array(StringType))
+ 'arrayField.array(StringType),
+ 'mapField.map(StringType, LongType))
def assertError(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
@@ -90,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")
- assertError(MaxOf('complexField, 'complexField),
+ assertError(MaxOf('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
- assertError(MinOf('complexField, 'complexField),
+ assertError(MinOf('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
}
@@ -109,20 +110,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))
- assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
- assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
+ assertErrorForDifferingTypes(EqualTo('intField, 'mapField))
+ assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
- assertError(LessThan('complexField, 'complexField),
+ assertError(LessThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
- assertError(LessThanOrEqual('complexField, 'complexField),
+ assertError(LessThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
- assertError(GreaterThan('complexField, 'complexField),
+ assertError(GreaterThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
- assertError(GreaterThanOrEqual('complexField, 'complexField),
+ assertError(GreaterThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(If('intField, 'stringField, 'stringField),
@@ -130,10 +131,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
assertError(
- CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
+ CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
- CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
+ CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
@@ -147,9 +148,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
// We will cast String to Double for sum and average
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
+ assertSuccess(Min('arrayField))
- assertError(Min('complexField), "min does not support ordering on type")
- assertError(Max('complexField), "max does not support ordering on type")
+ assertError(Min('mapField), "min does not support ordering on type")
+ assertError(Max('mapField), "max does not support ordering on type")
assertError(Sum('booleanField), "function sum requires numeric type")
assertError(Average('booleanField), "function average requires numeric type")
}
@@ -184,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
assertError(Round('intField, 'booleanField), "requires int type")
- assertError(Round('intField, 'complexField), "requires int type")
+ assertError(Round('intField, 'mapField), "requires int type")
assertError(Round('booleanField, 'intField), "requires numeric type")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index 05b870705e..bc07b609a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -48,4 +48,7 @@ object TestRelations {
val listRelation = LocalRelation(
AttributeReference("list", ArrayType(IntegerType))())
+
+ val mapRelation = LocalRelation(
+ AttributeReference("map", MapType(IntegerType, IntegerType))())
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e323467af5..002ed16dcf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.math._
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -49,40 +47,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
futures.foreach(Await.result(_, 10.seconds))
}
- // Test GenerateOrdering for all common types. For each type, we construct random input rows that
- // contain two columns of that type, then for pairs of randomly-generated rows we check that
- // GenerateOrdering agrees with RowOrdering.
- (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
- test(s"GenerateOrdering with $dataType") {
- val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
- val genOrdering = GenerateOrdering.generate(
- BoundReference(0, dataType, nullable = true).asc ::
- BoundReference(1, dataType, nullable = true).asc :: Nil)
- val rowType = StructType(
- StructField("a", dataType, nullable = true) ::
- StructField("b", dataType, nullable = true) :: Nil)
- val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
- assume(maybeDataGenerator.isDefined)
- val randGenerator = maybeDataGenerator.get
- val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
- for (_ <- 1 to 50) {
- val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
- val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
- withClue(s"a = $a, b = $b") {
- assert(genOrdering.compare(a, a) === 0)
- assert(genOrdering.compare(b, b) === 0)
- assert(rowOrdering.compare(a, a) === 0)
- assert(rowOrdering.compare(b, b) === 0)
- assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
- assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
- assert(
- signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
- "Generated and non-generated orderings should agree")
- }
- }
- }
- }
-
test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
new file mode 100644
index 0000000000..7ad8657bde
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.math._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{Row, RandomDataGenerator}
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.types._
+
+class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = {
+ test(s"compare two arrays: a = $a, b = $b") {
+ val dataType = ArrayType(IntegerType)
+ val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil)
+ val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+ val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow]
+ val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow]
+ Seq(Ascending, Descending).foreach { direction =>
+ val sortOrder = direction match {
+ case Ascending => BoundReference(0, dataType, nullable = true).asc
+ case Descending => BoundReference(0, dataType, nullable = true).desc
+ }
+ val expectedCompareResult = direction match {
+ case Ascending => signum(expected)
+ case Descending => -1 * signum(expected)
+ }
+ val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
+ val genOrdering = GenerateOrdering.generate(sortOrder :: Nil)
+ Seq(intOrdering, genOrdering).foreach { ordering =>
+ assert(ordering.compare(rowA, rowA) === 0)
+ assert(ordering.compare(rowB, rowB) === 0)
+ assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
+ assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
+ }
+ }
+ }
+ }
+
+ // Two arrays have the same size.
+ compareArrays(Seq[Any](), Seq[Any](), 0)
+ compareArrays(Seq[Any](1), Seq[Any](1), 0)
+ compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0)
+ compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1)
+
+ // Two arrays have different sizes.
+ compareArrays(Seq[Any](), Seq[Any](1), -1)
+ compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1)
+ compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1)
+ compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1)
+
+ // Arrays having nulls.
+ compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1)
+ compareArrays(Seq[Any](), Seq[Any](null), -1)
+ compareArrays(Seq[Any](null), Seq[Any](null), 0)
+ compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0)
+ compareArrays(Seq[Any](null), Seq[Any](null, null), -1)
+ compareArrays(Seq[Any](null), Seq[Any](1), -1)
+ compareArrays(Seq[Any](null), Seq[Any](null, 1), -1)
+ compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1)
+ compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0)
+ compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1)
+
+ // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+ // contain two columns of that type, then for pairs of randomly-generated rows we check that
+ // GenerateOrdering agrees with RowOrdering.
+ {
+ val structType =
+ new StructType()
+ .add("f1", FloatType, nullable = true)
+ .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true)
+ val arrayOfStructType = ArrayType(structType)
+ val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil
+ (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType =>
+ test(s"GenerateOrdering with $dataType") {
+ val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
+ val genOrdering = GenerateOrdering.generate(
+ BoundReference(0, dataType, nullable = true).asc ::
+ BoundReference(1, dataType, nullable = true).asc :: Nil)
+ val rowType = StructType(
+ StructField("a", dataType, nullable = true) ::
+ StructField("b", dataType, nullable = true) :: Nil)
+ val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+ assume(maybeDataGenerator.isDefined)
+ val randGenerator = maybeDataGenerator.get
+ val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+ for (_ <- 1 to 50) {
+ val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+ val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+ withClue(s"a = $a, b = $b") {
+ assert(genOrdering.compare(a, a) === 0)
+ assert(genOrdering.compare(b, b) === 0)
+ assert(rowOrdering.compare(a, a) === 0)
+ assert(rowOrdering.compare(b, b) === 0)
+ assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+ assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+ assert(
+ signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+ "Generated and non-generated orderings should agree")
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 3a3f19af14..aff9efe4b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -308,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null))
)
- val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
- assert(intercept[AnalysisException] {
- df2.selectExpr("sort_array(a)").collect()
- }.getMessage().contains("does not support sorting array of type array<int>"))
+ val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b")
+ checkAnswer(
+ df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"),
+ Seq(
+ Row(
+ Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)),
+ Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null)))
+ )
val df3 = Seq(("xxx", "x")).toDF("a", "b")
assert(intercept[AnalysisException] {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 61e3e913c2..6dde79f74d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -132,6 +132,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
(3, null, null)).toDF("key", "value1", "value2")
data2.write.saveAsTable("agg2")
+ val data3 = Seq[(Seq[Integer], Integer, Integer)](
+ (Seq[Integer](1, 1), 10, -10),
+ (Seq[Integer](null), -60, 60),
+ (Seq[Integer](1, 1), 30, -30),
+ (Seq[Integer](1), 30, 30),
+ (Seq[Integer](2), 1, 1),
+ (null, -10, 10),
+ (Seq[Integer](2, 3), -1, null),
+ (Seq[Integer](2, 3), 1, 1),
+ (Seq[Integer](2, 3, 4), null, 1),
+ (Seq[Integer](null), 100, -10),
+ (Seq[Integer](3), null, 3),
+ (null, null, null),
+ (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+ data3.write.saveAsTable("agg3")
+
val emptyDF = sqlContext.createDataFrame(
sparkContext.emptyRDD[Row],
StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -146,6 +162,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
override def afterAll(): Unit = {
sqlContext.sql("DROP TABLE IF EXISTS agg1")
sqlContext.sql("DROP TABLE IF EXISTS agg2")
+ sqlContext.sql("DROP TABLE IF EXISTS agg3")
sqlContext.dropTempTable("emptyTable")
}
@@ -266,6 +283,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(100, null) ::
Row(null, 3) ::
Row(null, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT DISTINCT key
+ |FROM agg3
+ """.stripMargin),
+ Row(Seq[Integer](1, 1)) ::
+ Row(Seq[Integer](null)) ::
+ Row(Seq[Integer](1)) ::
+ Row(Seq[Integer](2)) ::
+ Row(null) ::
+ Row(Seq[Integer](2, 3)) ::
+ Row(Seq[Integer](2, 3, 4)) ::
+ Row(Seq[Integer](3)) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT value1, key
+ |FROM agg3
+ |GROUP BY value1, key
+ """.stripMargin),
+ Row(10, Seq[Integer](1, 1)) ::
+ Row(-60, Seq[Integer](null)) ::
+ Row(30, Seq[Integer](1, 1)) ::
+ Row(30, Seq[Integer](1)) ::
+ Row(1, Seq[Integer](2)) ::
+ Row(-10, null) ::
+ Row(-1, Seq[Integer](2, 3)) ::
+ Row(1, Seq[Integer](2, 3)) ::
+ Row(null, Seq[Integer](2, 3, 4)) ::
+ Row(100, Seq[Integer](null)) ::
+ Row(null, Seq[Integer](3)) ::
+ Row(null, null) :: Nil)
}
test("case in-sensitive resolution") {