From 67d468f8d9172569ec9846edc6432240547696dd Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 20 Oct 2015 13:40:24 -0700
Subject: [SPARK-11111] [SQL] fast null-safe join

Currently, we use CartesianProduct for a join with a null-safe-equal condition.
```
scala> sqlContext.sql("select * from t a join t b on (a.i <=> b.i)").explain
== Physical Plan ==
TungstenProject [i#2,j#3,i#7,j#8]
 Filter (i#2 <=> i#7)
  CartesianProduct
   LocalTableScan [i#2,j#3], [[1,1]]
   LocalTableScan [i#7,j#8], [[1,1]]
```
Instead, we can express it as an equi-join condition `coalesce(a.i, default) = coalesce(b.i, default)`, so that a partitioned join algorithm can be used. After this PR, the plan becomes:
```
>>> sqlContext.sql("select * from a join b ON a.id <=> b.id").explain()
TungstenProject [id#0L,id#1L]
 Filter (id#0L <=> id#1L)
  SortMergeJoin [coalesce(id#0L,0)], [coalesce(id#1L,0)]
   TungstenSort [coalesce(id#0L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#0L,0),200)
     ConvertToUnsafe
      Scan PhysicalRDD[id#0L]
   TungstenSort [coalesce(id#1L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#1L,0),200)
     ConvertToUnsafe
      Scan PhysicalRDD[id#1L]
```

Author: Davies Liu

Closes #9120 from davies/null_safe.
---
 .../spark/sql/catalyst/expressions/literals.scala  | 29 ++++++++++++++++--
 .../spark/sql/catalyst/planning/patterns.scala     | 35 ++++++++++++++--------
 .../expressions/LiteralExpressionSuite.scala       | 28 ++++++++++++++++-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 14 +++++++++
 .../spark/sql/execution/joins/InnerJoinSuite.scala | 14 +++++++++
 5 files changed, 105 insertions(+), 15 deletions(-)
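The rewrite is safe in one direction only: `a <=> b` implies `coalesce(a, d) = coalesce(b, d)` for any fixed default `d`, so hashing or sorting on the coalesced key never separates rows that should join, while the converse fails when one side is null and the other happens to equal the default. That is why the plan above still keeps the `Filter (id#0L <=> id#1L)` on top of the SortMergeJoin. A minimal plain-Scala sketch of this argument (not taken from the patch; `Option[Int]` stands in for a nullable column and `0` for `Literal.default(IntegerType)`):
```
// Plain-Scala sketch of why the coalesce-based join key is safe.
// Option[Int] models a nullable column; `default` mirrors Literal.default(IntegerType).
object NullSafeJoinKeySketch {
  val default = 0

  // coalesce(x, default): the value used as the equi-join key.
  def key(x: Option[Int]): Int = x.getOrElse(default)

  // The original null-safe equality (<=>): null <=> null is true.
  def nullSafeEq(a: Option[Int], b: Option[Int]): Boolean = a == b

  def main(args: Array[String]): Unit = {
    val pairs = Seq(
      (Some(1), Some(1)),      // equal non-null values
      (None, None),            // both null: <=> is true, keys match
      (None, Some(default)),   // null vs. default: keys collide, but <=> is false
      (Some(1), Some(2)))      // different values
    for ((a, b) <- pairs) {
      // a <=> b implies key(a) == key(b): partitioning on the key never drops a
      // matching pair. The converse is false (third pair), so the original <=>
      // predicate must still be evaluated as a Filter above the join.
      assert(!nullSafeEq(a, b) || key(a) == key(b))
    }
  }
}
```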
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 51be819e9d..455fa2427c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -52,6 +51,32 @@ object Literal {
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
   }
+
+  /**
+   * Create a literal with the default value for the given DataType.
+   */
+  def default(dataType: DataType): Literal = dataType match {
+    case NullType => create(null, NullType)
+    case BooleanType => Literal(false)
+    case ByteType => Literal(0.toByte)
+    case ShortType => Literal(0.toShort)
+    case IntegerType => Literal(0)
+    case LongType => Literal(0L)
+    case FloatType => Literal(0.0f)
+    case DoubleType => Literal(0.0)
+    case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
+    case DateType => create(0, DateType)
+    case TimestampType => create(0L, TimestampType)
+    case StringType => Literal("")
+    case BinaryType => Literal("".getBytes)
+    case CalendarIntervalType => Literal(new CalendarInterval(0, 0))
+    case arr: ArrayType => create(Array(), arr)
+    case map: MapType => create(Map(), map)
+    case struct: StructType =>
+      create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+    case other =>
+      throw new RuntimeException(s"no default for type $dataType")
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5353779951..3b975b904a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.catalyst.planning
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -160,6 +160,9 @@ object PartialAggregation {
 
 /**
  * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
+ *
+ * Null-safe equality is transformed into an ordinary equality join key by replacing null with
+ * the default value of the key's type.
  */
 object ExtractEquiJoinKeys extends Logging with PredicateHelper {
   /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
@@ -171,17 +174,25 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
-      val (joinPredicates, otherPredicates) =
-        condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
-          case EqualTo(l, r) =>
-            (canEvaluate(l, left) && canEvaluate(r, right)) ||
-              (canEvaluate(l, right) && canEvaluate(r, left))
-          case _ => false
-        }
-
-      val joinKeys = joinPredicates.map {
-        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
-        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+      val joinKeys = predicates.flatMap {
+        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
+        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
+        // Replace null with the type's default value in the joining key, so that rows whose key
+        // is null can still be joined together.
+        case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
+          Some((Coalesce(Seq(l, Literal.default(l.dataType))),
+            Coalesce(Seq(r, Literal.default(r.dataType)))))
+        case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
+          Some((Coalesce(Seq(r, Literal.default(r.dataType))),
+            Coalesce(Seq(l, Literal.default(l.dataType)))))
+        case other => None
+      }
+      val otherPredicates = predicates.filterNot {
+        case EqualTo(l, r) =>
+          canEvaluate(l, left) && canEvaluate(r, right) ||
+            canEvaluate(l, right) && canEvaluate(r, left)
+        case other => false
       }
 
       if (joinKeys.nonEmpty) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 015eb1897f..7b85286c4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,10 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 
 class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -30,15 +33,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, IntegerType), null)
     checkEvaluation(Literal.create(null, LongType), null)
     checkEvaluation(Literal.create(null, FloatType), null)
-    checkEvaluation(Literal.create(null, LongType), null)
+    checkEvaluation(Literal.create(null, DoubleType), null)
     checkEvaluation(Literal.create(null, StringType), null)
     checkEvaluation(Literal.create(null, BinaryType), null)
     checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
+    checkEvaluation(Literal.create(null, DateType), null)
+    checkEvaluation(Literal.create(null, TimestampType), null)
+    checkEvaluation(Literal.create(null, CalendarIntervalType), null)
     checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
     checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
    checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
   }
 
+  test("default") {
+    checkEvaluation(Literal.default(BooleanType), false)
+    checkEvaluation(Literal.default(ByteType), 0.toByte)
+    checkEvaluation(Literal.default(ShortType), 0.toShort)
+    checkEvaluation(Literal.default(IntegerType), 0)
+    checkEvaluation(Literal.default(LongType), 0L)
+    checkEvaluation(Literal.default(FloatType), 0.0f)
+    checkEvaluation(Literal.default(DoubleType), 0.0)
+    checkEvaluation(Literal.default(StringType), "")
+    checkEvaluation(Literal.default(BinaryType), "".getBytes)
+    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+    checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L))
+    checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
+    checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+    checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+    checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+  }
+
   test("boolean literals") {
     checkEvaluation(Literal(true), true)
     checkEvaluation(Literal(false), false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 636591630e..a35a7f41dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -850,6 +851,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(null, null, 6, "F") :: Nil)
   }
 
+  test("SPARK-11111 null-safe join should not use cartesian product") {
+    val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)")
+    val cp = df.queryExecution.executedPlan.collect {
+      case cp: CartesianProduct => cp
+    }
+    assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
+    val smj = df.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoin => smj
+    }
+    assert(smj.size > 0, "should use SortMergeJoin")
+    checkAnswer(df, Row(100) :: Nil)
+  }
+
   test("SPARK-3349 partitioning after limit") {
     sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
       .limit(2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4174ee0550..da58e96f3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -212,4 +212,18 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     )
   }
 
+  {
+    lazy val left = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+    lazy val right = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+    testInnerJoin(
+      "inner join, null safe",
+      left,
+      right,
+      () => (left.col("b") <=> right.col("b")).expr,
+      Seq(
+        (1, 0, 1, 0),
+        (2, null, 2, null)
+      )
+    )
+  }
 }
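For reference, a minimal sketch of how the pieces fit together at the expression level: the new `Literal.default` supplies the per-type default that `ExtractEquiJoinKeys` wraps in `Coalesce`. It assumes the `spark-catalyst` module of this era on the classpath; the attribute and value names are illustrative, not taken from the patch.
```
// Sketch only: build the coalesce-based equi-join keys for `a.i <=> b.i`.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, EqualNullSafe, Literal}
import org.apache.spark.sql.types.IntegerType

object NullSafeKeys {
  // Two nullable integer join columns, one from each side of the join.
  val leftKey = AttributeReference("i", IntegerType, nullable = true)()
  val rightKey = AttributeReference("i", IntegerType, nullable = true)()

  // The null-safe predicate that previously forced a CartesianProduct.
  val condition = EqualNullSafe(leftKey, rightKey)

  // The equi-join keys built from it: coalesce(i, 0) on each side,
  // where Literal.default(IntegerType) supplies the 0.
  val newLeftKey = Coalesce(Seq(leftKey, Literal.default(leftKey.dataType)))
  val newRightKey = Coalesce(Seq(rightKey, Literal.default(rightKey.dataType)))
}
```
Note that in the patch the `EqualNullSafe` predicate is deliberately not removed by the `filterNot`, so it stays in `otherPredicates` and is still evaluated after the join; the coalesced keys are used only for partitioning and ordering.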