| author | Davies Liu <davies@databricks.com> | 2015-10-20 13:40:24 -0700 |
|---|---|---|
| committer | Davies Liu <davies.liu@gmail.com> | 2015-10-20 13:40:24 -0700 |
| commit | 67d468f8d9172569ec9846edc6432240547696dd | |
| tree | d9edf59834ee77c564b58ee12387db8ad17a9145 /sql/catalyst/src | |
| parent | 478c7ce8628c05ebce2972e631d76317accebe9c | |
[SPARK-11111] [SQL] fast null-safe join
Currently, we use CartesianProduct for joins with a null-safe equality (`<=>`) condition:
```
scala> sqlContext.sql("select * from t a join t b on (a.i <=> b.i)").explain
== Physical Plan ==
TungstenProject [i#2,j#3,i#7,j#8]
Filter (i#2 <=> i#7)
CartesianProduct
LocalTableScan [i#2,j#3], [[1,1]]
LocalTableScan [i#7,j#8], [[1,1]]
```
Actually, we can rewrite it as an equi-join condition, `coalesce(a.i, default) = coalesce(b.i, default)`, so that a partitioned join algorithm can be used. Since a null key and a key that actually equals the default value coalesce to the same thing, the original null-safe condition is kept as a filter on top of the join to discard those spurious matches.
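To see why the rewrite is sound, here is a minimal standalone sketch in plain Scala (no Spark; `nullSafeEq`, `coalesceKey`, and the sample data are invented for illustration):

```scala
object NullSafeJoinSketch {
  // SQL's <=>: null <=> null is true, null <=> x is false; Option equality models this.
  def nullSafeEq(a: Option[Int], b: Option[Int]): Boolean = a == b

  // coalesce(key, default): map null to the type's default value so rows with
  // null keys land in the same partition / sort group as each other.
  def coalesceKey(a: Option[Int], default: Int = 0): Int = a.getOrElse(default)

  def main(args: Array[String]): Unit = {
    val left  = Seq(Some(1), None, Some(0))
    val right = Seq(Some(1), None, Some(0))
    // Step 1: an ordinary equi-join on the coalesced keys (what a partitioned
    // join like SortMergeJoin can execute).
    val candidates =
      for (l <- left; r <- right if coalesceKey(l) == coalesceKey(r)) yield (l, r)
    // Step 2: the retained null-safe filter drops spurious matches such as
    // (None, Some(0)), which collide only because null coalesces to the default 0.
    val joined = candidates.filter { case (l, r) => nullSafeEq(l, r) }
    println(joined) // List((Some(1),Some(1)), (None,None), (Some(0),Some(0)))
  }
}
```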
After this PR, the plan will become:
```
>>> sqlContext.sql("select * from a join b ON a.id <=> b.id").explain()
TungstenProject [id#0L,id#1L]
Filter (id#0L <=> id#1L)
SortMergeJoin [coalesce(id#0L,0)], [coalesce(id#1L,0)]
TungstenSort [coalesce(id#0L,0) ASC], false, 0
TungstenExchange hashpartitioning(coalesce(id#0L,0),200)
ConvertToUnsafe
Scan PhysicalRDD[id#0L]
TungstenSort [coalesce(id#1L,0) ASC], false, 0
TungstenExchange hashpartitioning(coalesce(id#1L,0),200)
ConvertToUnsafe
Scan PhysicalRDD[id#1L]
```
Author: Davies Liu <davies@databricks.com>
Closes #9120 from davies/null_safe.
Diffstat (limited to 'sql/catalyst/src')
3 files changed, 77 insertions, 15 deletions
```diff
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 51be819e9d..455fa2427c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -52,6 +51,32 @@ object Literal {
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }
+
+  /**
+   * Create a literal with default value for given DataType
+   */
+  def default(dataType: DataType): Literal = dataType match {
+    case NullType => create(null, NullType)
+    case BooleanType => Literal(false)
+    case ByteType => Literal(0.toByte)
+    case ShortType => Literal(0.toShort)
+    case IntegerType => Literal(0)
+    case LongType => Literal(0L)
+    case FloatType => Literal(0.0f)
+    case DoubleType => Literal(0.0)
+    case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
+    case DateType => create(0, DateType)
+    case TimestampType => create(0L, TimestampType)
+    case StringType => Literal("")
+    case BinaryType => Literal("".getBytes)
+    case CalendarIntervalType => Literal(new CalendarInterval(0, 0))
+    case arr: ArrayType => create(Array(), arr)
+    case map: MapType => create(Map(), map)
+    case struct: StructType =>
+      create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+    case other =>
+      throw new RuntimeException(s"no default for type $dataType")
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5353779951..3b975b904a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.catalyst.planning
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -160,6 +160,9 @@ object PartialAggregation {
 
 /**
  * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
+ *
+ * Null-safe equality will be transformed into equality as joining key (replace null with default
+ * value).
  */
 object ExtractEquiJoinKeys extends Logging with PredicateHelper {
   /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
@@ -171,17 +174,25 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
-      val (joinPredicates, otherPredicates) =
-        condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
-          case EqualTo(l, r) =>
-            (canEvaluate(l, left) && canEvaluate(r, right)) ||
-              (canEvaluate(l, right) && canEvaluate(r, left))
-          case _ => false
-        }
-
-      val joinKeys = joinPredicates.map {
-        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
-        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+      val joinKeys = predicates.flatMap {
+        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
+        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
+        // Replace null with default value for joining key, then those rows with null in it could
+        // be joined together
+        case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
+          Some((Coalesce(Seq(l, Literal.default(l.dataType))),
+            Coalesce(Seq(r, Literal.default(r.dataType)))))
+        case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
+          Some((Coalesce(Seq(r, Literal.default(r.dataType))),
+            Coalesce(Seq(l, Literal.default(l.dataType)))))
+        case other => None
+      }
+      val otherPredicates = predicates.filterNot {
+        case EqualTo(l, r) =>
+          canEvaluate(l, left) && canEvaluate(r, right) ||
+            canEvaluate(l, right) && canEvaluate(r, left)
+        case other => false
       }
 
       if (joinKeys.nonEmpty) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 015eb1897f..7b85286c4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,10 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 
 class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -30,15 +33,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, IntegerType), null)
     checkEvaluation(Literal.create(null, LongType), null)
     checkEvaluation(Literal.create(null, FloatType), null)
-    checkEvaluation(Literal.create(null, LongType), null)
+    checkEvaluation(Literal.create(null, DoubleType), null)
     checkEvaluation(Literal.create(null, StringType), null)
     checkEvaluation(Literal.create(null, BinaryType), null)
     checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
+    checkEvaluation(Literal.create(null, DateType), null)
+    checkEvaluation(Literal.create(null, TimestampType), null)
+    checkEvaluation(Literal.create(null, CalendarIntervalType), null)
     checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
     checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
     checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
   }
 
+  test("default") {
+    checkEvaluation(Literal.default(BooleanType), false)
+    checkEvaluation(Literal.default(ByteType), 0.toByte)
+    checkEvaluation(Literal.default(ShortType), 0.toShort)
+    checkEvaluation(Literal.default(IntegerType), 0)
+    checkEvaluation(Literal.default(LongType), 0L)
+    checkEvaluation(Literal.default(FloatType), 0.0f)
+    checkEvaluation(Literal.default(DoubleType), 0.0)
+    checkEvaluation(Literal.default(StringType), "")
+    checkEvaluation(Literal.default(BinaryType), "".getBytes)
+    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+    checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L))
+    checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
+    checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+    checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+    checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+  }
+
   test("boolean literals") {
     checkEvaluation(Literal(true), true)
     checkEvaluation(Literal(false), false)
```
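For reference, a brief hypothetical sketch of the join key that the `EqualNullSafe` cases above construct (`col` is an invented attribute used only for illustration; `Coalesce`, `Literal.default`, and `AttributeReference` are the Catalyst classes used in the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, Literal}
import org.apache.spark.sql.types.LongType

object NullSafeKeySketch {
  def main(args: Array[String]): Unit = {
    // A hypothetical nullable join column, standing in for a.id in the example above.
    val col = AttributeReference("id", LongType, nullable = true)()
    // The same shape the EqualNullSafe cases build: coalesce(id, 0L), so rows
    // with null keys can be hash-partitioned and sorted like any other rows.
    val joinKey = Coalesce(Seq(col, Literal.default(col.dataType)))
    println(joinKey) // prints something like coalesce(id#0L,0)
  }
}
```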