about summary refs log tree commit diff
diff options
context:
space:
mode:
author    Wenchen Fan <wenchen@databricks.com>  2015-11-18 10:15:50 -0800
committer Michael Armbrust <michael@databricks.com>  2015-11-18 10:15:50 -0800
commit    cffb899c4397ecccedbcc41e7cf3da91f953435a (patch)
tree      56a30a519b936cf945ec59cc3337faf123012a8e
parent    3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 (diff)
download  spark-cffb899c4397ecccedbcc41e7cf3da91f953435a.tar.gz
          spark-cffb899c4397ecccedbcc41e7cf3da91f953435a.tar.bz2
          spark-cffb899c4397ecccedbcc41e7cf3da91f953435a.zip
[SPARK-11803][SQL] fix Dataset self-join
When we resolve the join operator, we may change the output of the right side if a self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9806 from cloud-fan/self-join.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala      | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |  8
2 files changed, 13 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 817c20fdbb..b644f6ad30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -498,13 +498,17 @@ class Dataset[T] private[sql](
val left = this.logicalPlan
val right = other.logicalPlan
+ val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr)))
+ val leftOutput = joined.analyzed.output.take(left.output.length)
+ val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+
val leftData = this.unresolvedTEncoder match {
- case e if e.flat => Alias(left.output.head, "_1")()
- case _ => Alias(CreateStruct(left.output), "_1")()
+ case e if e.flat => Alias(leftOutput.head, "_1")()
+ case _ => Alias(CreateStruct(leftOutput), "_1")()
}
val rightData = other.unresolvedTEncoder match {
- case e if e.flat => Alias(right.output.head, "_2")()
- case _ => Alias(CreateStruct(right.output), "_2")()
+ case e if e.flat => Alias(rightOutput.head, "_2")()
+ case _ => Alias(CreateStruct(rightOutput), "_2")()
}
@@ -513,7 +517,7 @@ class Dataset[T] private[sql](
withPlan[(T, U)](other) { (left, right) =>
Project(
leftData :: rightData :: Nil,
- Join(left, right, Inner, Some(condition.expr)))
+ joined.analyzed)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a522894c37..198962b8fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(joined, ("2", 2))
}
- ignore("self join") {
+ test("self join") {
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true))
checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
@@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("kryo encoder") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
- val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+ val ds = Seq(KryoData(1), KryoData(2)).toDS()
assert(ds.groupBy(p => p).count().collect().toSeq ==
Seq((KryoData(1), 1L), (KryoData(2), 1L)))
}
- ignore("kryo encoder self join") {
+ test("kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
- val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+ val ds = Seq(KryoData(1), KryoData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
Set(
(KryoData(1), KryoData(1)),