author     Yin Huai <yhuai@databricks.com>    2015-09-22 13:31:35 -0700
committer  Yin Huai <yhuai@databricks.com>    2015-09-22 13:31:35 -0700
commit     5aea987c904b281d7952ad8db40a32561b4ec5cf (patch)
tree       163ace452106c5ce5975302ea46e022d4c5e870e /sql
parent     2204cdb28483b249616068085d4e88554fe6acef (diff)
[SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results
https://issues.apache.org/jira/browse/SPARK-10737

Author: Yin Huai <yhuai@databricks.com>

Closes #8854 from yhuai/SMJBug.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 25
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 28
4 files changed, 59 insertions, 5 deletions
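
Note on the root cause, before the per-file changes: when a child operator produces UnsafeRows, keys have to be extracted with UnsafeProjection, which understands the binary row layout, rather than with the GenerateProjection-based projection that newProjection may return. A second hazard, visible in the Window fix, is that these projections reuse a single output row across calls, so anything retained across iterations must be copied. Below is a minimal, self-contained sketch of that reuse hazard in plain Scala; ReusingProjection and its Array-backed "rows" are hypothetical stand-ins, not Spark classes.

    // A projection that rewrites one shared buffer on every call, the contract
    // that makes the copy() calls in this commit necessary.
    object ReusedBufferSketch {
      final class ReusingProjection {
        private val buffer = new Array[Int](1)
        def apply(key: Int): Array[Int] = { buffer(0) = key; buffer }
      }

      def main(args: Array[String]): Unit = {
        val project = new ReusingProjection

        // Retaining results without copying: every element aliases the same buffer,
        // so all of them end up holding the last key that was projected.
        val aliased = (1 to 3).map(project(_))
        println(aliased.map(_.toSeq))          // three identical keys

        // Copying before retaining keeps each key stable, analogous to
        // nextGroup.copy() and nextRow.copy() in Window.fetchNextPartition.
        val copied = (1 to 3).map(project(_).clone())
        println(copied.map(_.toSeq))           // keys 1, 2, 3
      }
    }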
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 2164ddf03d..75524b568d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
@Override
public Object apply(Object r) {
+ // GenerateProjection does not work with UnsafeRows.
+ assert(!(r instanceof ${classOf[UnsafeRow].getName}));
return new SpecificRow((InternalRow) r);
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 0269d6d4b7..f8929530c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -253,7 +253,11 @@ case class Window(
// Get all relevant projections.
val result = createResultProjection(unboundExpressions)
- val grouping = newProjection(partitionSpec, child.output)
+ val grouping = if (child.outputsUnsafeRows) {
+ UnsafeProjection.create(partitionSpec, child.output)
+ } else {
+ newProjection(partitionSpec, child.output)
+ }
// Manage the stream and the grouping.
var nextRow: InternalRow = EmptyRow
@@ -277,7 +281,8 @@ case class Window(
val numFrames = frames.length
private[this] def fetchNextPartition() {
// Collect all the rows in the current partition.
- val currentGroup = nextGroup
+ // Before we start to fetch new input rows, make a copy of nextGroup.
+ val currentGroup = nextGroup.copy()
rows = new CompactBuffer
while (nextRowAvailable && nextGroup == currentGroup) {
rows += nextRow.copy()
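
Why the copy of nextGroup matters: once the grouping projection is an UnsafeProjection, nextGroup is presumably backed by a reused output row, so an un-copied currentGroup would alias that buffer and the "same group?" comparison could never fail, merging partitions together. A small, self-contained sketch of the boundary check, with hypothetical Key/project stand-ins for Spark's row and projection types:

    // Plain Scala sketch of the fetchNextPartition loop shape; not Spark code.
    object GroupBoundarySketch {
      final class Key(var value: String) {
        override def equals(other: Any): Boolean = other match {
          case k: Key => k.value == value
          case _      => false
        }
        override def hashCode: Int = value.hashCode
        def copy(): Key = new Key(value)
      }

      def main(args: Array[String]): Unit = {
        val buffer = new Key("")                                 // the reused output row
        def project(partition: String): Key = { buffer.value = partition; buffer }

        val input = Iterator("a" -> 1, "a" -> 2, "b" -> 3).buffered
        val currentGroup = project(input.head._1).copy()         // drop .copy() to reproduce the bug
        val rows = scala.collection.mutable.ArrayBuffer[Int]()
        while (input.hasNext && project(input.head._1) == currentGroup) {
          rows += input.next()._2
        }
        // With the copy the loop stops at the partition boundary (rows 1 and 2);
        // without it, currentGroup aliases the buffer and row 3 leaks in as well.
        println(rows)
      }
    }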
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 906f20d2a7..70a1af6a70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -56,9 +56,6 @@ case class SortMergeJoin(
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
- @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
- @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
-
protected[this] def isUnsafeMode: Boolean = {
(codegenEnabled && unsafeEnabled
&& UnsafeProjection.canSupport(leftKeys)
@@ -82,6 +79,28 @@ case class SortMergeJoin(
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
new RowIterator {
+ // The projection used to extract keys from input rows of the left child.
+ private[this] val leftKeyGenerator = {
+ if (isUnsafeMode) {
+ // It is very important to use UnsafeProjection if input rows are UnsafeRows.
+ // Otherwise, GenerateProjection will cause wrong results.
+ UnsafeProjection.create(leftKeys, left.output)
+ } else {
+ newProjection(leftKeys, left.output)
+ }
+ }
+
+ // The projection used to extract keys from input rows of the right child.
+ private[this] val rightKeyGenerator = {
+ if (isUnsafeMode) {
+ // It is very important to use UnsafeProjection if input rows are UnsafeRows.
+ // Otherwise, GenerateProjection will cause wrong results.
+ UnsafeProjection.create(rightKeys, right.output)
+ } else {
+ newProjection(rightKeys, right.output)
+ }
+ }
+
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
private[this] var currentLeftRow: InternalRow = _
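
The substantive change here is that the key generators are now created per partition and, in unsafe mode, via UnsafeProjection, so extracted keys share the format of the incoming rows; the old GenerateProjection-based generators could silently misbehave on UnsafeRow input, which the new assert in GenerateProjection now documents. For context on why both generators feed one keyOrdering, here is a compact, generic sketch of a sort-merge inner join over plain Scala collections; it mirrors the shape of the algorithm only, not Spark's RowIterator implementation.

    // Both inputs must already be sorted by key under the same Ordering; runs of
    // equal keys on the two sides are joined as a cross product.
    def sortMergeInnerJoin[K, L, R](left: IndexedSeq[L], right: IndexedSeq[R])
                                   (leftKey: L => K, rightKey: R => K)
                                   (implicit ord: Ordering[K]): Seq[(L, R)] = {
      val out = Seq.newBuilder[(L, R)]
      var i = 0
      var j = 0
      while (i < left.length && j < right.length) {
        val cmp = ord.compare(leftKey(left(i)), rightKey(right(j)))
        if (cmp < 0) i += 1            // left key is smaller: advance the left side
        else if (cmp > 0) j += 1       // right key is smaller: advance the right side
        else {
          val k    = leftKey(left(i))
          val lRun = left.drop(i).takeWhile(l => ord.equiv(leftKey(l), k))
          val rRun = right.drop(j).takeWhile(r => ord.equiv(rightKey(r), k))
          for (l <- lRun; r <- rRun) out += ((l, r))
          i += lRun.length
          j += rRun.length
        }
      }
      out.result()
    }

    // e.g. sortMergeInnerJoin(Vector("a" -> 1, "b" -> 2), Vector("a" -> 10, "a" -> 11))(_._1, _._1)
    // pairs ("a", 1) with both ("a", 10) and ("a", 11), and drops ("b", 2).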
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 05b4127cbc..eca6f10738 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1), Row(1)))
}
}
+
+ test("SortMergeJoin returns wrong results when using UnsafeRows") {
+ // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737.
+ // This bug will be triggered when Tungsten is enabled and there are multiple
+ // SortMergeJoin operators executed in the same task.
+ val confs =
+ SQLConf.SORTMERGE_JOIN.key -> "true" ::
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" ::
+ SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil
+ withSQLConf(confs: _*) {
+ val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")
+ val df2 =
+ df1
+ .join(df1.select(df1("i")), "i")
+ .select(df1("i"), df1("j"))
+
+ val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1")
+ val df4 =
+ df2
+ .join(df3, df2("i") === df3("i1"))
+ .withColumn("diff", $"j" - $"j1")
+ .select(df2("i"), df2("j"), $"diff")
+
+ checkAnswer(
+ df4,
+ df1.withColumn("diff", lit(0)))
+ }
+ }
}
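
A note on how the regression test reaches the bug (an interpretation, based on the configs above): SORTMERGE_JOIN prefers sort-merge over shuffled hash join, an AUTO_BROADCASTJOIN_THRESHOLD of 1 byte effectively disables broadcast joins so every join in the chain becomes a SortMergeJoin, and TUNGSTEN_ENABLED makes operators exchange UnsafeRows. The join keys are the strings "str_1" to "str_50", which is significant because string values read from an UnsafeRow reference the row's backing bytes rather than a private copy. Since df4 pairs each row with itself via column "i", j - j1 should be 0 for every row; before this fix, stale keys could match the wrong rows and yield nonzero differences, the wrong-result symptom reported in SPARK-10737.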