author     Davies Liu <davies@databricks.com>    2015-10-20 13:40:24 -0700
committer  Davies Liu <davies.liu@gmail.com>     2015-10-20 13:40:24 -0700
commit     67d468f8d9172569ec9846edc6432240547696dd (patch)
tree       d9edf59834ee77c564b58ee12387db8ad17a9145 /sql
parent     478c7ce8628c05ebce2972e631d76317accebe9c (diff)
[SPARK-11111] [SQL] fast null-safe join
Currently, we use CartesianProduct for joins with a null-safe equality condition:

```
scala> sqlContext.sql("select * from t a join t b on (a.i <=> b.i)").explain
== Physical Plan ==
TungstenProject [i#2,j#3,i#7,j#8]
 Filter (i#2 <=> i#7)
  CartesianProduct
   LocalTableScan [i#2,j#3], [[1,1]]
   LocalTableScan [i#7,j#8], [[1,1]]
```

Actually, we can rewrite this as an equi-join condition, `coalesce(a.i, default) = coalesce(b.i, default)`, so that a partitioned join algorithm can be used. After this PR, the plan becomes:

```
>>> sqlContext.sql("select * from a join b ON a.id <=> b.id").explain()
TungstenProject [id#0L,id#1L]
 Filter (id#0L <=> id#1L)
  SortMergeJoin [coalesce(id#0L,0)], [coalesce(id#1L,0)]
   TungstenSort [coalesce(id#0L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#0L,0),200)
     ConvertToUnsafe
      Scan PhysicalRDD[id#0L]
   TungstenSort [coalesce(id#1L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#1L,0),200)
     ConvertToUnsafe
      Scan PhysicalRDD[id#1L]
```

Author: Davies Liu <davies@databricks.com>

Closes #9120 from davies/null_safe.
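The core of the patch is a new pattern match in ExtractEquiJoinKeys (full diff below). As a minimal sketch, assuming only the Catalyst expressions this patch itself uses (`EqualNullSafe`, `Coalesce`, `Literal.default`), the rewrite amounts to:

```scala
import org.apache.spark.sql.catalyst.expressions._

// Sketch only, not the patch itself: turn a null-safe equality into an
// ordinary equality over coalesced keys, which hash- and sort-merge-based
// joins can partition on. Plain EqualTo keys pass through unchanged.
def asEquiJoinKey(e: Expression): Option[(Expression, Expression)] = e match {
  case EqualTo(l, r) => Some((l, r))
  case EqualNullSafe(l, r) =>
    Some((Coalesce(Seq(l, Literal.default(l.dataType))),
          Coalesce(Seq(r, Literal.default(r.dataType)))))
  case _ => None
}
```

The coalesced expressions serve only as join keys; the original `<=>` predicate is kept and re-applied as a Filter above the join (visible in both plans above), since coalesce maps null and the default value to the same key.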
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala  29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala  35
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala  14
5 files changed, 105 insertions, 15 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 51be819e9d..455fa2427c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -52,6 +51,32 @@ object Literal {
def create(v: Any, dataType: DataType): Literal = {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
+
+ /**
+ * Create a literal with the default value for the given DataType.
+ */
+ def default(dataType: DataType): Literal = dataType match {
+ case NullType => create(null, NullType)
+ case BooleanType => Literal(false)
+ case ByteType => Literal(0.toByte)
+ case ShortType => Literal(0.toShort)
+ case IntegerType => Literal(0)
+ case LongType => Literal(0L)
+ case FloatType => Literal(0.0f)
+ case DoubleType => Literal(0.0)
+ case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
+ case DateType => create(0, DateType)
+ case TimestampType => create(0L, TimestampType)
+ case StringType => Literal("")
+ case BinaryType => Literal("".getBytes)
+ case CalendarIntervalType => Literal(new CalendarInterval(0, 0))
+ case arr: ArrayType => create(Array(), arr)
+ case map: MapType => create(Map(), map)
+ case struct: StructType =>
+ create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+ case other =>
+ throw new RuntimeException(s"no default for type $dataType")
+ }
}
/**
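For orientation (illustration only, not part of the patch), the values `Literal.default` produces follow directly from the match arms above:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

// Numeric types default to zero, string/binary to empty, and complex
// types recurse into their element/field types.
Literal.default(IntegerType)   // Literal(0)
Literal.default(DoubleType)    // Literal(0.0)
Literal.default(StringType)    // Literal("")
Literal.default(
  StructType(StructField("a", StringType) :: Nil))  // row with "" for field a
```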
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5353779951..3b975b904a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.catalyst.planning
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
/**
* A pattern that matches any number of project or filter operations on top of another relational
@@ -160,6 +160,9 @@ object PartialAggregation {
/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
+ *
+ * Null-safe equality will be transformed into an ordinary equality join key, with null
+ * replaced by the type's default value.
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
/** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
@@ -171,17 +174,25 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and thus can be used
// as join keys.
- val (joinPredicates, otherPredicates) =
- condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
- case EqualTo(l, r) =>
- (canEvaluate(l, left) && canEvaluate(r, right)) ||
- (canEvaluate(l, right) && canEvaluate(r, left))
- case _ => false
- }
-
- val joinKeys = joinPredicates.map {
- case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
- case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+ val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+ val joinKeys = predicates.flatMap {
+ case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
+ case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
+ // Replace null with the default value for the join key, so that rows containing null
+ // can still be joined together.
+ case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
+ Some((Coalesce(Seq(l, Literal.default(l.dataType))),
+ Coalesce(Seq(r, Literal.default(r.dataType)))))
+ case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
+ Some((Coalesce(Seq(r, Literal.default(r.dataType))),
+ Coalesce(Seq(l, Literal.default(l.dataType)))))
+ case other => None
+ }
+ val otherPredicates = predicates.filterNot {
+ case EqualTo(l, r) =>
+ canEvaluate(l, left) && canEvaluate(r, right) ||
+ canEvaluate(l, right) && canEvaluate(r, left)
+ case other => false
}
if (joinKeys.nonEmpty) {
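One subtlety worth noting (illustration, not part of the patch): coalescing maps null and the default value to the same join key, so the extracted keys alone would co-locate, and wrongly match, a null row with a default-valued row. This is why the `EqualNullSafe` predicate is deliberately left out of the filterNot above, so it stays in `otherPredicates` and is re-checked as a Filter above the join, as seen in the plans in the commit message. A minimal plain-Scala sketch of that interplay:

```scala
// Plain Scala, no Spark required. Models coalesce(key, default) as the join
// key and <=> as the retained filter predicate.
object NullSafeKeyDemo extends App {
  val keys: Seq[Option[Int]] = Seq(None, Some(0), Some(1))

  def coalesceKey(k: Option[Int]): Int = k.getOrElse(0)             // coalesce(k, default)
  def nullSafeEq(a: Option[Int], b: Option[Int]): Boolean = a == b  // <=> semantics

  for (a <- keys; b <- keys if coalesceKey(a) == coalesceKey(b)) {
    // These pairs meet in the same partition; only null-safe-equal ones survive.
    println(s"$a vs $b -> kept by filter: ${nullSafeEq(a, b)}")
  }
  // (None, Some(0)) and (Some(0), None) meet on key 0 but are filtered out.
}
```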
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 015eb1897f..7b85286c4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -30,15 +33,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, IntegerType), null)
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, FloatType), null)
- checkEvaluation(Literal.create(null, LongType), null)
+ checkEvaluation(Literal.create(null, DoubleType), null)
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, BinaryType), null)
checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
+ checkEvaluation(Literal.create(null, DateType), null)
+ checkEvaluation(Literal.create(null, TimestampType), null)
+ checkEvaluation(Literal.create(null, CalendarIntervalType), null)
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
}
+ test("default") {
+ checkEvaluation(Literal.default(BooleanType), false)
+ checkEvaluation(Literal.default(ByteType), 0.toByte)
+ checkEvaluation(Literal.default(ShortType), 0.toShort)
+ checkEvaluation(Literal.default(IntegerType), 0)
+ checkEvaluation(Literal.default(LongType), 0L)
+ checkEvaluation(Literal.default(FloatType), 0.0f)
+ checkEvaluation(Literal.default(DoubleType), 0.0)
+ checkEvaluation(Literal.default(StringType), "")
+ checkEvaluation(Literal.default(BinaryType), "".getBytes)
+ checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+ checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+ checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+ checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L))
+ checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
+ checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+ checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+ checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+ }
+
test("boolean literals") {
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 636591630e..a35a7f41dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -850,6 +851,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(null, null, 6, "F") :: Nil)
}
+ test("SPARK-11111 null-safe join should not use cartesian product") {
+ val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)")
+ val cp = df.queryExecution.executedPlan.collect {
+ case cp: CartesianProduct => cp
+ }
+ assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
+ val smj = df.queryExecution.executedPlan.collect {
+ case smj: SortMergeJoin => smj
+ }
+ assert(smj.size > 0, "should use SortMergeJoin")
+ checkAnswer(df, Row(100) :: Nil)
+ }
+
test("SPARK-3349 partitioning after limit") {
sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
.limit(2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4174ee0550..da58e96f3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -212,4 +212,18 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
)
}
+ {
+ lazy val left = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+ lazy val right = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+ testInnerJoin(
+ "inner join, null safe",
+ left,
+ right,
+ () => (left.col("b") <=> right.col("b")).expr,
+ Seq(
+ (1, 0, 1, 0),
+ (2, null, 2, null)
+ )
+ )
+ }
}