author | Liang-Chi Hsieh <viirya@appier.com> | 2015-08-02 17:53:44 -0700
---|---|---
committer | Reynold Xin <rxin@databricks.com> | 2015-08-02 17:53:44 -0700
commit | 0722f43316fc7ed0c1308b0f9d6d15f0c22ed56f |
tree | 2407fdf49ec8e8e6fd60ee0d98fd808952ab3b2f /sql/catalyst |
parent | 2e981b7bfa9dec93fdcf25f3e7220cd6aaba744f |
[SPARK-7937][SQL] Support comparison on StructType
This brings #6519 up to date with the master branch.
Closes #6519.
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Reynold Xin <rxin@databricks.com>
Closes #7877 from rxin/sort-struct and squashes the following commits:
4968231 [Reynold Xin] Minor fixes.
2537813 [Reynold Xin] Merge branch 'compare_named_struct' of github.com:viirya/spark-1 into sort-struct
d2ba8ad [Liang-Chi Hsieh] Remove unused import.
3a3f40e [Liang-Chi Hsieh] Don't need to add compare to InternalRow because we can use RowOrdering.
dae6aad [Liang-Chi Hsieh] Fix nested struct.
d5349c7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct
43d4354 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct
1f66196 [Liang-Chi Hsieh] Reuse RowOrdering and GenerateOrdering.
f8b2e9c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct
1187a65 [Liang-Chi Hsieh] Fix scala style.
9d67f68 [Liang-Chi Hsieh] Fix wrongly merging.
8f4d775 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct
94b27d5 [Liang-Chi Hsieh] Remove test for error on complex type comparison.
2071693 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into compare_named_struct
3c142e4 [Liang-Chi Hsieh] Fix scala style.
cf58dc3 [Liang-Chi Hsieh] Use checkAnswer.
f651b8d [Liang-Chi Hsieh] Remove Either and move orderings to BinaryComparison to reuse it.
b6e1009 [Liang-Chi Hsieh] Fix scala style.
3922b54 [Liang-Chi Hsieh] Support ordering on named_struct.
Diffstat (limited to 'sql/catalyst')
10 files changed, 93 insertions, 14 deletions
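
Before the file-by-file diff, a quick sketch of what this change enables at the query level. This is illustrative only, assuming a Spark 1.5-era `sqlContext` and the `struct` function from `org.apache.spark.sql.functions` (the DataFrame itself is made up):

```scala
import org.apache.spark.sql.functions._

// Struct values now compare field by field (field 1, then field 2, ...),
// so ordering by a struct column should pass analysis instead of failing
// the old "non-complex type" check.
val df = sqlContext.range(3)
  .select(struct(col("id").as("a"), lit("x").as("b")).as("s"))
df.orderBy(col("s").desc).show()

// Aggregates that rely on ordering, such as min/max, should work as well.
df.agg(min(col("s"))).show()
```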
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3177e6b750..3c91227d06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -80,6 +80,16 @@ class CodeGenContext {
     mutableStates += ((javaType, variableName, initCode))
   }
 
+  /**
+   * Holding all the functions those will be added into generated class.
+   */
+  val addedFuntions: mutable.Map[String, String] =
+    mutable.Map.empty[String, String]
+
+  def addNewFunction(funcName: String, funcCode: String): Unit = {
+    addedFuntions += ((funcName, funcCode))
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -221,6 +231,19 @@ class CodeGenContext {
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case NullType => "0"
+    case schema: StructType if schema.supportOrdering(schema) =>
+      val comparisons = GenerateOrdering.genComparisons(this, schema)
+      val compareFunc = freshName("compareStruct")
+      val funcCode: String =
+        s"""
+          public int $compareFunc(InternalRow a, InternalRow b) {
+            InternalRow i = null;
+            $comparisons
+            return 0;
+          }
+        """
+      addNewFunction(compareFunc, funcCode)
+      s"this.$compareFunc($c1, $c2)"
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
     case _ => throw new IllegalArgumentException(
       "cannot generate compare code for un-comparable type")
@@ -262,11 +285,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
   protected def declareMutableStates(ctx: CodeGenContext): String = {
     ctx.mutableStates.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
-    }.mkString("\n ")
+    }.mkString
   }
 
   protected def initMutableStates(ctx: CodeGenContext): String = {
-    ctx.mutableStates.map(_._3).mkString("\n ")
+    ctx.mutableStates.map(_._3).mkString
+  }
+
+  protected def declareAddedFunctions(ctx: CodeGenContext): String = {
+    ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 825031a4fa..e4a8fc24da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -92,6 +92,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
       private $exprType[] expressions;
       private $mutableRowType mutableRow;
       ${declareMutableStates(ctx)}
+      ${declareAddedFunctions(ctx)}
 
       public SpecificProjection($exprType[] expr) {
         expressions = expr;
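
To make the new codegen hook concrete: a generator registers a named helper method on the context via `addNewFunction`, and each generated `Specific*` class splices the registered helpers in through `declareAddedFunctions(ctx)`. A minimal sketch of the registration side, with a placeholder body (the real patch generates the struct comparator this way):

```scala
// Inside a CodeGenerator, assuming ctx: CodeGenContext is in scope.
val compareFunc = ctx.freshName("compareStruct") // unique method name
ctx.addNewFunction(compareFunc,
  s"""
    public int $compareFunc(InternalRow a, InternalRow b) {
      return 0; // placeholder body, for illustration only
    }
  """)
// declareAddedFunctions(ctx) later emits this method into the generated
// class, so comparison code can invoke it as this.$compareFunc(c1, c2).
```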
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index cc848aa199..4da91ed8d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -53,9 +53,21 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     })
   }
 
-  protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
-    val ctx = newCodeGenContext()
+  /**
+   * Generates the code for comparing a struct type according to its natural ordering
+   * (i.e. ascending order by field 1, then field 2, ..., then field n.
+   */
+  def genComparisons(ctx: CodeGenContext, schema: StructType): String = {
+    val ordering = schema.fields.map(_.dataType).zipWithIndex.map {
+      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+    }
+    genComparisons(ctx, ordering)
+  }
 
+  /**
+   * Generates the code for ordering based on the given order.
+   */
+  def genComparisons(ctx: CodeGenContext, ordering: Seq[SortOrder]): String = {
     val comparisons = ordering.map { order =>
       val eval = order.child.gen(ctx)
       val asc = order.direction == Ascending
@@ -94,6 +106,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
       }
     """
     }.mkString("\n")
+    comparisons
+  }
+
+  protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
+    val ctx = newCodeGenContext()
+    val comparisons = genComparisons(ctx, ordering)
     val code = s"""
       public SpecificOrdering generate($exprType[] expr) {
         return new SpecificOrdering(expr);
@@ -103,6 +121,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
       private $exprType[] expressions;
 
       ${declareMutableStates(ctx)}
+      ${declareAddedFunctions(ctx)}
 
       public SpecificOrdering($exprType[] expr) {
         expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index dfd593fb7c..c7e718a526 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -48,6 +48,8 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
     class SpecificPredicate extends ${classOf[Predicate].getName} {
       private final $exprType[] expressions;
       ${declareMutableStates(ctx)}
+      ${declareAddedFunctions(ctx)}
+
       public SpecificPredicate($exprType[] expr) {
         expressions = expr;
         ${initMutableStates(ctx)}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 6f9acda071..1572b2b99a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -159,6 +159,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     class SpecificProjection extends ${classOf[BaseProjection].getName} {
       private $exprType[] expressions;
       ${declareMutableStates(ctx)}
+      ${declareAddedFunctions(ctx)}
 
       public SpecificProjection($exprType[] expr) {
         expressions = expr;
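
The "natural ordering" that the struct overload of `genComparisons` builds is simply one ascending `SortOrder` per field, bound by ordinal. A standalone sketch of that construction (hypothetical helper name; the types are the same ones the patch uses):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, SortOrder}
import org.apache.spark.sql.types.StructType

// Hypothetical equivalent of the schema-based genComparisons overload:
// each field becomes an ascending sort key on its position, treated as nullable.
def naturalOrdering(schema: StructType): Seq[SortOrder] =
  schema.fields.map(_.dataType).zipWithIndex.map {
    case (dt, i) => SortOrder(BoundReference(i, dt, nullable = true), Ascending)
  }
```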
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6c99086046..934ec3f75c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -274,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       private $exprType[] expressions;
 
       ${declareMutableStates(ctx)}
+      ${declareAddedFunctions(ctx)}
 
       public SpecificUnsafeProjection($exprType[] expressions) {
         this.expressions = expressions;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 73f6b7a550..7e1031c755 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -145,6 +145,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
             n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
           case n: AtomicType if order.direction == Descending =>
             n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case s: StructType if order.direction == Ascending =>
+            s.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case s: StructType if order.direction == Descending =>
+            s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
           case other => sys.error(s"Type $other does not support ordered operations")
         }
         if (comparison != 0) return comparison
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 0103ddcf9c..2f50d40fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -33,11 +33,18 @@ object TypeUtils {
   }
 
   def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
-    if (t.isInstanceOf[AtomicType] || t == NullType) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t")
+    t match {
+      case i: AtomicType => TypeCheckResult.TypeCheckSuccess
+      case n: NullType => TypeCheckResult.TypeCheckSuccess
+      case s: StructType =>
+        if (s.supportOrdering(s)) {
+          TypeCheckResult.TypeCheckSuccess
+        } else {
+          TypeCheckResult.TypeCheckFailure(s"Fields in $s do not support ordering")
+        }
+      case other => TypeCheckResult.TypeCheckFailure(s"$t doesn't support ordering on $caller")
     }
+  }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
@@ -52,8 +59,12 @@ object TypeUtils {
   def getNumeric(t: DataType): Numeric[Any] =
     t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
 
-  def getOrdering(t: DataType): Ordering[Any] =
-    t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+  def getOrdering(t: DataType): Ordering[Any] = {
+    t match {
+      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+      case s: StructType => s.ordering.asInstanceOf[Ordering[Any]]
+    }
+  }
 
   def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
     for (i <- 0 until x.length; if i < y.length) {
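
The net effect of the `TypeUtils` change: structs whose fields are all orderable now pass the ordering type check, while other complex types still fail. A short sketch of the expected behavior (the result comments are my reading of the new code, not output from a run):

```scala
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

// A struct of atomic fields should yield TypeCheckSuccess...
TypeUtils.checkForOrderingExpr(
  StructType(Seq(StructField("x", IntegerType))), "function min")

// ...a struct containing a non-orderable field should fail with
// "Fields in ... do not support ordering"...
TypeUtils.checkForOrderingExpr(
  StructType(Seq(StructField("xs", ArrayType(IntegerType)))), "function min")

// ...and a bare complex type should fail with the new
// "... doesn't support ordering on function min" message.
TypeUtils.checkForOrderingExpr(ArrayType(IntegerType), "function min")
```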
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 2ef97a427c..2f23144858 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering}
 
 /**
@@ -300,8 +300,21 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     StructType(newFields)
   }
-}
 
+  private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType))
+
+  private[sql] def supportOrdering(s: StructType): Boolean = {
+    s.fields.forall { f =>
+      if (f.dataType.isInstanceOf[AtomicType]) {
+        true
+      } else if (f.dataType.isInstanceOf[StructType]) {
+        supportOrdering(f.dataType.asInstanceOf[StructType])
+      } else {
+        false
+      }
+    }
+  }
+}
 
 object StructType extends AbstractDataType {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a52e4cb4df..8f616ae9d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -145,8 +145,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(SumDistinct('stringField))
     assertSuccess(Average('stringField))
 
-    assertError(Min('complexField), "function min accepts non-complex type")
-    assertError(Max('complexField), "function max accepts non-complex type")
+    assertError(Min('complexField), "doesn't support ordering on function min")
+    assertError(Max('complexField), "doesn't support ordering on function max")
     assertError(Sum('booleanField), "function sum accepts numeric type")
     assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
     assertError(Average('booleanField), "function average accepts numeric type")
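
`StructType.ordering` is what the interpreted `RowOrdering` path dispatches to for struct values; a sketch of building and applying one directly (two `InternalRow` values `a` and `b` matching the schema are assumed to exist):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType)))

// Same construction the patch caches on StructType: an Ordering[InternalRow]
// that compares ascending, field by field.
val ord: Ordering[InternalRow] = RowOrdering.forSchema(schema.fields.map(_.dataType))
// ord.compare(a, b) < 0 when row `a` sorts before row `b`
```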