aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-06-01 16:16:54 -0700
committerCheng Lian <lian@databricks.com>2016-06-01 16:16:54 -0700
commit8640cdb836b4964e4af891d9959af64a2e1f304e (patch)
treef2a4c1da3aa44ae75189b3476b3039cc0ab5ac78 /sql
parent7bb64aae27f670531699f59d3f410e38866609b7 (diff)
downloadspark-8640cdb836b4964e4af891d9959af64a2e1f304e.tar.gz
spark-8640cdb836b4964e4af891d9959af64a2e1f304e.tar.bz2
spark-8640cdb836b4964e4af891d9959af64a2e1f304e.zip
[SPARK-15441][SQL] support null object in Dataset outer-join
## What changes were proposed in this pull request? Currently we can't encode top level null object into internal row, as Spark SQL doesn't allow row to be null, only its columns can be null. This is not a problem before, as we assume the input object is never null. However, for outer join, we do need the semantics of null object. This PR fixes this problem by making both join sides produce a single column, i.e. nest the logical plan output(by `CreateStruct`), so that we have an extra level to represent top level null obejct. ## How was this patch tested? new test in `DatasetSuite` Author: Wenchen Fan <wenchen@databricks.com> Closes #13425 from cloud-fan/outer-join2.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala67
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala23
4 files changed, 59 insertions, 35 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index f21a39a2d4..2296946cd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -125,12 +125,13 @@ object ExpressionEncoder {
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
- enc.deserializer.transformUp {
+ val deserialized = enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
}
+ If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2f2323fa3a..c2e3ab82ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.Modifier
-import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3a6ec4595e..369b772d32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -747,31 +747,62 @@ class Dataset[T] private[sql](
*/
@Experimental
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
- val left = this.logicalPlan
- val right = other.logicalPlan
-
- val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType =
- JoinType(joinType), Some(condition.expr)))
- val leftOutput = joined.analyzed.output.take(left.output.length)
- val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+ // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
+ // etc.
+ val joined = sparkSession.sessionState.executePlan(
+ Join(
+ this.logicalPlan,
+ other.logicalPlan,
+ JoinType(joinType),
+ Some(condition.expr))).analyzed.asInstanceOf[Join]
+
+ // For both join side, combine all outputs into a single column and alias it with "_1" or "_2",
+ // to match the schema for the encoder of the join result.
+ // Note that we do this before joining them, to enable the join operator to return null for one
+ // side, in cases like outer-join.
+ val left = {
+ val combined = if (this.unresolvedTEncoder.flat) {
+ assert(joined.left.output.length == 1)
+ Alias(joined.left.output.head, "_1")()
+ } else {
+ Alias(CreateStruct(joined.left.output), "_1")()
+ }
+ Project(combined :: Nil, joined.left)
+ }
- val leftData = this.unresolvedTEncoder match {
- case e if e.flat => Alias(leftOutput.head, "_1")()
- case _ => Alias(CreateStruct(leftOutput), "_1")()
+ val right = {
+ val combined = if (other.unresolvedTEncoder.flat) {
+ assert(joined.right.output.length == 1)
+ Alias(joined.right.output.head, "_2")()
+ } else {
+ Alias(CreateStruct(joined.right.output), "_2")()
+ }
+ Project(combined :: Nil, joined.right)
}
- val rightData = other.unresolvedTEncoder match {
- case e if e.flat => Alias(rightOutput.head, "_2")()
- case _ => Alias(CreateStruct(rightOutput), "_2")()
+
+ // Rewrites the join condition to make the attribute point to correct column/field, after we
+ // combine the outputs of each join side.
+ val conditionExpr = joined.condition.get transformUp {
+ case a: Attribute if joined.left.outputSet.contains(a) =>
+ if (this.unresolvedTEncoder.flat) {
+ left.output.head
+ } else {
+ val index = joined.left.output.indexWhere(_.exprId == a.exprId)
+ GetStructField(left.output.head, index)
+ }
+ case a: Attribute if joined.right.outputSet.contains(a) =>
+ if (other.unresolvedTEncoder.flat) {
+ right.output.head
+ } else {
+ val index = joined.right.output.indexWhere(_.exprId == a.exprId)
+ GetStructField(right.output.head, index)
+ }
}
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withTypedPlan {
- Project(
- leftData :: rightData :: Nil,
- joined.analyzed)
- }
+ withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8fc4dc9f17..0b6874e3b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(1, 1), (2, 2))
}
- test("joinWith, expression condition, outer join") {
- val nullInteger = null.asInstanceOf[Integer]
- val nullString = null.asInstanceOf[String]
- val ds1 = Seq(ClassNullableData("a", 1),
- ClassNullableData("c", 3)).toDS()
- val ds2 = Seq(("a", new Integer(1)),
- ("b", new Integer(2))).toDS()
-
- checkDataset(
- ds1.joinWith(ds2, $"_1" === $"a", "outer"),
- (ClassNullableData("a", 1), ("a", new Integer(1))),
- (ClassNullableData("c", 3), (nullString, nullInteger)),
- (ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
- }
-
test("joinWith tuple with primitive, expression") {
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()
@@ -783,6 +768,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ds.filter(_.b > 1).collect().toSeq
}
}
+
+ test("SPARK-15441: Dataset outer join") {
+ val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
+ val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
+ val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+ val result = joined.collect().toSet
+ assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
+ }
}
case class Generic[T](id: T, value: Double)