aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@appier.com>2015-08-02 17:53:44 -0700
committerReynold Xin <rxin@databricks.com>2015-08-02 17:53:44 -0700
commit0722f43316fc7ed0c1308b0f9d6d15f0c22ed56f (patch)
tree2407fdf49ec8e8e6fd60ee0d98fd808952ab3b2f
parent2e981b7bfa9dec93fdcf25f3e7220cd6aaba744f (diff)
downloadspark-0722f43316fc7ed0c1308b0f9d6d15f0c22ed56f.tar.gz
spark-0722f43316fc7ed0c1308b0f9d6d15f0c22ed56f.tar.bz2
spark-0722f43316fc7ed0c1308b0f9d6d15f0c22ed56f.zip
[SPARK-7937][SQL] Support comparison on StructType
This brings #6519 up-to-date with master branch. Closes #6519. Author: Liang-Chi Hsieh <viirya@appier.com> Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Reynold Xin <rxin@databricks.com> Closes #7877 from rxin/sort-struct and squashes the following commits: 4968231 [Reynold Xin] Minor fixes. 2537813 [Reynold Xin] Merge branch 'compare_named_struct' of github.com:viirya/spark-1 into sort-struct d2ba8ad [Liang-Chi Hsieh] Remove unused import. 3a3f40e [Liang-Chi Hsieh] Don't need to add compare to InternalRow because we can use RowOrdering. dae6aad [Liang-Chi Hsieh] Fix nested struct. d5349c7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct 43d4354 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct 1f66196 [Liang-Chi Hsieh] Reuse RowOrdering and GenerateOrdering. f8b2e9c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct 1187a65 [Liang-Chi Hsieh] Fix scala style. 9d67f68 [Liang-Chi Hsieh] Fix wrongly merging. 8f4d775 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct 94b27d5 [Liang-Chi Hsieh] Remove test for error on complex type comparison. 2071693 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct 3c142e4 [Liang-Chi Hsieh] Fix scala style. cf58dc3 [Liang-Chi Hsieh] Use checkAnswer. f651b8d [Liang-Chi Hsieh] Remove Either and move orderings to BinaryComparison to reuse it. b6e1009 [Liang-Chi Hsieh] Fix scala style. 3922b54 [Liang-Chi Hsieh] Support ordering on named_struct.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala31
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala23
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala23
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala17
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala43
11 files changed, 135 insertions, 15 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3177e6b750..3c91227d06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -80,6 +80,16 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initCode))
}
+ /**
+ * Holds all the functions that will be added into the generated class.
+ */
+ val addedFuntions: mutable.Map[String, String] =
+ mutable.Map.empty[String, String]
+
+ def addNewFunction(funcName: String, funcCode: String): Unit = {
+ addedFuntions += ((funcName, funcCode))
+ }
+
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -221,6 +231,19 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
+ case schema: StructType if schema.supportOrdering(schema) =>
+ val comparisons = GenerateOrdering.genComparisons(this, schema)
+ val compareFunc = freshName("compareStruct")
+ val funcCode: String =
+ s"""
+ public int $compareFunc(InternalRow a, InternalRow b) {
+ InternalRow i = null;
+ $comparisons
+ return 0;
+ }
+ """
+ addNewFunction(compareFunc, funcCode)
+ s"this.$compareFunc($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case _ => throw new IllegalArgumentException(
"cannot generate compare code for un-comparable type")
@@ -262,11 +285,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def declareMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
- }.mkString("\n ")
+ }.mkString
}
protected def initMutableStates(ctx: CodeGenContext): String = {
- ctx.mutableStates.map(_._3).mkString("\n ")
+ ctx.mutableStates.map(_._3).mkString
+ }
+
+ protected def declareAddedFunctions(ctx: CodeGenContext): String = {
+ ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 825031a4fa..e4a8fc24da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -92,6 +92,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificProjection($exprType[] expr) {
expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index cc848aa199..4da91ed8d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -53,9 +53,21 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
})
}
- protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
- val ctx = newCodeGenContext()
+ /**
+ * Generates the code for comparing a struct type according to its natural ordering
+ * (i.e. ascending order by field 1, then field 2, ..., then field n).
+ */
+ def genComparisons(ctx: CodeGenContext, schema: StructType): String = {
+ val ordering = schema.fields.map(_.dataType).zipWithIndex.map {
+ case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ }
+ genComparisons(ctx, ordering)
+ }
+ /**
+ * Generates the code for ordering based on the given order.
+ */
+ def genComparisons(ctx: CodeGenContext, ordering: Seq[SortOrder]): String = {
val comparisons = ordering.map { order =>
val eval = order.child.gen(ctx)
val asc = order.direction == Ascending
@@ -94,6 +106,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
"""
}.mkString("\n")
+ comparisons
+ }
+
+ protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
+ val ctx = newCodeGenContext()
+ val comparisons = genComparisons(ctx, ordering)
val code = s"""
public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
@@ -103,6 +121,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
private $exprType[] expressions;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificOrdering($exprType[] expr) {
expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index dfd593fb7c..c7e718a526 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -48,6 +48,8 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
class SpecificPredicate extends ${classOf[Predicate].getName} {
private final $exprType[] expressions;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
+
public SpecificPredicate($exprType[] expr) {
expressions = expr;
${initMutableStates(ctx)}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 6f9acda071..1572b2b99a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -159,6 +159,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
class SpecificProjection extends ${classOf[BaseProjection].getName} {
private $exprType[] expressions;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificProjection($exprType[] expr) {
expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6c99086046..934ec3f75c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -274,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private $exprType[] expressions;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificUnsafeProjection($exprType[] expressions) {
this.expressions = expressions;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 73f6b7a550..7e1031c755 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -145,6 +145,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
case n: AtomicType if order.direction == Descending =>
n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+ case s: StructType if order.direction == Ascending =>
+ s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+ case s: StructType if order.direction == Descending =>
+ s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case other => sys.error(s"Type $other does not support ordered operations")
}
if (comparison != 0) return comparison
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 0103ddcf9c..2f50d40fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -33,11 +33,18 @@ object TypeUtils {
}
def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
- if (t.isInstanceOf[AtomicType] || t == NullType) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t")
+ t match {
+ case i: AtomicType => TypeCheckResult.TypeCheckSuccess
+ case n: NullType => TypeCheckResult.TypeCheckSuccess
+ case s: StructType =>
+ if (s.supportOrdering(s)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(s"Fields in $s do not support ordering")
+ }
+ case other => TypeCheckResult.TypeCheckFailure(s"$t doesn't support ordering on $caller")
}
+
}
def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
@@ -52,8 +59,12 @@ object TypeUtils {
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
- def getOrdering(t: DataType): Ordering[Any] =
- t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+ def getOrdering(t: DataType): Ordering[Any] = {
+ t match {
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+ case s: StructType => s.ordering.asInstanceOf[Ordering[Any]]
+ }
+ }
def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 2ef97a427c..2f23144858 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering}
/**
@@ -300,8 +300,21 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(newFields)
}
-}
+ private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType))
+
+ private[sql] def supportOrdering(s: StructType): Boolean = {
+ s.fields.forall { f =>
+ if (f.dataType.isInstanceOf[AtomicType]) {
+ true
+ } else if (f.dataType.isInstanceOf[StructType]) {
+ supportOrdering(f.dataType.asInstanceOf[StructType])
+ } else {
+ false
+ }
+ }
+ }
+}
object StructType extends AbstractDataType {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a52e4cb4df..8f616ae9d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -145,8 +145,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(SumDistinct('stringField))
assertSuccess(Average('stringField))
- assertError(Min('complexField), "function min accepts non-complex type")
- assertError(Max('complexField), "function max accepts non-complex type")
+ assertError(Min('complexField), "doesn't support ordering on function min")
+ assertError(Max('complexField), "doesn't support ordering on function max")
assertError(Sum('booleanField), "function sum accepts numeric type")
assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
assertError(Average('booleanField), "function average accepts numeric type")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 1bde5922b5..7069afc9f7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.io.Writable
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.util.Utils
@@ -93,6 +93,47 @@ class HiveUDFSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
}
+ test("Max/Min on named_struct") {
+ def testOrderInStruct(): Unit = {
+ checkAnswer(sql(
+ """
+ |SELECT max(named_struct(
+ | "key", key,
+ | "value", value)).value FROM src
+ """.stripMargin), Seq(Row("val_498")))
+ checkAnswer(sql(
+ """
+ |SELECT min(named_struct(
+ | "key", key,
+ | "value", value)).value FROM src
+ """.stripMargin), Seq(Row("val_0")))
+
+ // nested struct cases
+ checkAnswer(sql(
+ """
+ |SELECT max(named_struct(
+ | "key", named_struct(
+ "key", key,
+ "value", value),
+ | "value", value)).value FROM src
+ """.stripMargin), Seq(Row("val_498")))
+ checkAnswer(sql(
+ """
+ |SELECT min(named_struct(
+ | "key", named_struct(
+ "key", key,
+ "value", value),
+ | "value", value)).value FROM src
+ """.stripMargin), Seq(Row("val_0")))
+ }
+ val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED)
+ TestHive.setConf(SQLConf.CODEGEN_ENABLED, true)
+ testOrderInStruct()
+ TestHive.setConf(SQLConf.CODEGEN_ENABLED, false)
+ testOrderInStruct()
+ TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+ }
+
test("SPARK-6409 UDAFAverage test") {
sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
checkAnswer(