author    Davies Liu <davies@databricks.com>    2015-10-13 09:40:36 -0700
committer Yin Huai <yhuai@databricks.com>       2015-10-13 09:43:33 -0700
commit    6987c067937a50867b4d5788f5bf496ecdfdb62c
tree      f5af8b048dc16f9c61481e7136848274b648ca60 /sql
parent    626aab79c9b4d4ac9d65bf5fa45b81dd9cbc609c
[SPARK-11009] [SQL] fix wrong result of Window function in cluster mode
Currently, all window functions can produce wrong results in cluster mode. The root cause is that an AttributeReference is created on the executor, so its expression ID is not guaranteed to be unique against the IDs created on the driver. Here is a script that reproduces the problem (run it against a local cluster):

```python
from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql.window import Window
from pyspark.sql.functions import rowNumber

sqlContext = HiveContext(SparkContext())
sqlContext.setConf("spark.sql.shuffle.partitions", "3")

df = sqlContext.range(1 << 20)
df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias("B"))
ws = Window.partitionBy(df2.A).orderBy(df2.B)
df3 = df2.select(df2.A, df2.B, rowNumber().over(ws).alias("rn")).filter("rn < 0")

# rowNumber() always yields values >= 1, so this must hold; with the ID
# collision the filter can return rows instead.
assert df3.count() == 0
```

Author: Davies Liu <davies@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #9050 from davies/wrong_window.
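Why does an executor-created AttributeReference collide? Expression IDs are handed out from a counter that is local to each JVM: in local mode the driver and executor share one JVM (and one counter), but in cluster mode every executor JVM starts its own counter from zero. A minimal sketch of the failure mode, using hypothetical names rather than Spark's actual classes:

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical sketch of a per-JVM expression-ID generator.
object ExprIdGenerator {
  private val curId = new AtomicLong(0)
  def next(): Long = curId.getAndIncrement()
}

// Driver JVM: the attributes of the plan receive IDs 0, 1, 2, ...
// Executor JVM: its counter also starts at 0, so an attribute created
// there (like "ordExpr" or "windowResult" in the old code below) can
// receive an ID that the driver already assigned to a different column.
// Resolution by ID then binds the window operator to the wrong slot.
```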
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala          | 20
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 41
2 files changed, 51 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index f8929530c5..55035f4bc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -145,11 +145,10 @@ case class Window(
// Construct the ordering. This is used to compare the result of current value projection
// to the result of bound value projection. This is done manually because we want to use
// Code Generation (if it is enabled).
- val (sortExprs, schema) = exprs.map { case e =>
- val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
- (SortOrder(ref, e.direction), ref)
- }.unzip
- val ordering = newOrdering(sortExprs, schema)
+ val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+ SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+ }
+ val ordering = newOrdering(sortExprs, Nil)
RangeBoundOrdering(ordering, current, bound)
case RowFrame => RowBoundOrdering(offset)
}
@@ -205,14 +204,15 @@ case class Window(
*/
private[this] def createResultProjection(
expressions: Seq[Expression]): MutableProjection = {
- val unboundToAttr = expressions.map {
- e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+ val references = expressions.zipWithIndex.map{ case (e, i) =>
+ // Results of window expressions will be on the right side of child's output
+ BoundReference(child.output.size + i, e.dataType, e.nullable)
}
- val unboundToAttrMap = unboundToAttr.toMap
- val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+ val unboundToRefMap = expressions.zip(references).toMap
+ val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
newMutableProjection(
projectList ++ patchedWindowExpression,
- child.output ++ unboundToAttr.map(_._2))()
+ child.output)()
}
protected override def doExecute(): RDD[InternalRow] = {
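Both hunks make the same change: rather than minting fresh AttributeReferences on the executor and resolving them by ID, the operator now binds to its input by ordinal position, which is deterministic in any JVM. A simplified sketch of the idea, not Spark's actual BoundReference implementation:

```scala
// Simplified illustration of ordinal binding (hypothetical classes).
object BindingSketch {
  // A minimal row: just an array of column values.
  final case class SimpleRow(values: Array[Any])

  // A reference that reads a fixed slot of the input row. No generated
  // expression ID is involved, so driver and executors always agree on
  // which column it denotes.
  final case class BoundRef(ordinal: Int) {
    def eval(row: SimpleRow): Any = row.values(ordinal)
  }

  // In createResultProjection, window results sit to the right of the
  // child's output, so result i is read from slot childOutputSize + i.
  def windowResultRef(childOutputSize: Int, i: Int): BoundRef =
    BoundRef(childOutputSize + i)
}
```

This is also why `newMutableProjection` can now be given just `child.output`: the appended window results are addressed purely by position.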
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 5f1660b62d..10e4ae2c50 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
runSparkSubmit(args)
}
+ test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val args = Seq(
+ "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+ "--name", "SparkSQLConfTest",
+ "--master", "local-cluster[2,1,1024]",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
// This is copied from org.apache.spark.deploy.SparkSubmitSuite
private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest {
}
}
}
+
+object SPARK_11009 extends QueryTest {
+ import org.apache.spark.sql.functions._
+
+ protected var sqlContext: SQLContext = _
+
+ def main(args: Array[String]): Unit = {
+ Utils.configTestLog4j("INFO")
+
+ val sparkContext = new SparkContext(
+ new SparkConf()
+ .set("spark.ui.enabled", "false")
+ .set("spark.sql.shuffle.partitions", "100"))
+
+ val hiveContext = new TestHiveContext(sparkContext)
+ sqlContext = hiveContext
+
+ try {
+ val df = sqlContext.range(1 << 20)
+ val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+ val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+ val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+ if (df3.rdd.count() != 0) {
+ throw new Exception("df3 should have 0 output row.")
+ }
+ } finally {
+ sparkContext.stop()
+ }
+ }
+}
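A small harness detail in the test above: the `--class` argument is built with `SPARK_11009.getClass.getName.stripSuffix("$")`. A Scala `object` compiles to a singleton class whose name ends in `$`, while the JVM-visible class carrying the static `main` forwarder drops the suffix. A quick illustration with a hypothetical object:

```scala
// Hypothetical object, standing in for SPARK_11009.
object Hello {
  def main(args: Array[String]): Unit = println("hello")
}

// Hello.getClass.getName == "Hello$"   (the singleton's own class)
// spark-submit needs the class exposing a static main(), i.e. "Hello",
// hence Hello.getClass.getName.stripSuffix("$").
```

The `local-cluster[2,1,1024]` master runs two executor JVMs with one core and 1024 MB each, which is exactly the multi-JVM setup needed to trigger the expression-ID collision.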