Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala |  19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                               |  60
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala               |   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                          | 118
4 files changed, 186 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 1f76b03bcb..a5bdee1b85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,10 +31,19 @@ case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
extends RedistributeData
/**
- * This method repartitions data using [[Expression]]s, and receives information about the
- * number of partitions during execution. Used when a specific ordering or distribution is
- * expected by the consumer of the query result. Use [[Repartition]] for RDD-like
+ * This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+ * information about the number of partitions during execution. Used when a specific ordering or
+ * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
* `coalesce` and `repartition`.
+ * If `numPartitions` is not specified, the number of partitions will be the number set by
+ * `spark.sql.shuffle.partitions`.
*/
-case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan)
- extends RedistributeData
+case class RepartitionByExpression(
+ partitionExpressions: Seq[Expression],
+ child: LogicalPlan,
+ numPartitions: Option[Int] = None) extends RedistributeData {
+ numPartitions match {
+ case Some(n) => require(n > 0, "numPartitions must be greater than 0.")
+ case None => // Ok
+ }
+}
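
A minimal, self-contained sketch (not code from this patch; the class and parameter names below are hypothetical) of how the optional partition count is intended to behave, evaluated in a Scala REPL, with `defaultShufflePartitions` standing in for the value of spark.sql.shuffle.partitions:

// Hypothetical sketch mirroring the validation and fallback of the node above.
case class RepartitionSketch(numPartitions: Option[Int] = None) {
  // Same check as the logical node: an explicit partition count must be positive.
  numPartitions.foreach(n => require(n > 0, "numPartitions must be greater than 0."))

  // The planner falls back to the session default when no count is given.
  def resolvedPartitions(defaultShufflePartitions: Int): Int =
    numPartitions.getOrElse(defaultShufflePartitions)
}

RepartitionSketch(Some(5)).resolvedPartitions(200)   // 5
RepartitionSketch().resolvedPartitions(200)          // 200, i.e. spark.sql.shuffle.partitions
// RepartitionSketch(Some(0))                        // throws IllegalArgumentException
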
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index aa817a037e..53ad3c0266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -241,6 +241,18 @@ class DataFrame private[sql](
sb.toString()
}
+ private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = global, logicalPlan)
+ }
+
override def toString: String = {
try {
schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
@@ -633,15 +645,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def sort(sortExprs: Column*): DataFrame = {
- val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
- col.expr match {
- case expr: SortOrder =>
- expr
- case expr: Expression =>
- SortOrder(expr, Ascending)
- }
- }
- Sort(sortOrder, global = true, logicalPlan)
+ sortInternal(true, sortExprs)
}
/**
@@ -663,6 +667,44 @@ class DataFrame private[sql](
def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
/**
+ * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into
+ * `numPartitions`. The resulting DataFrame is hash partitioned.
+ * @group dfops
+ * @since 1.6.0
+ */
+ def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = {
+ RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions))
+ }
+
+ /**
+ * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving
+ * the existing number of partitions. The resulting DataFrame is hash partitioned.
+ * @group dfops
+ * @since 1.6.0
+ */
+ def distributeBy(partitionExprs: Seq[Column]): DataFrame = {
+ RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None)
+ }
+
+ /**
+ * Returns a new [[DataFrame]] with each partition sorted by the given columns.
+ * @group dfops
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def localSort(sortCol: String, sortCols: String*): DataFrame = localSort((sortCol +: sortCols).map(Column(_)) : _*)
+
+ /**
+ * Returns a new [[DataFrame]] with each partition sorted by the given expressions.
+ * @group dfops
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def localSort(sortExprs: Column*): DataFrame = {
+ sortInternal(false, sortExprs)
+ }
+
+ /**
* Selects column based on the column name and return it as a [[Column]].
* Note that the column name can also reference to a nested column like `a.b`.
* @group dfops
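
A usage sketch of the API added above (illustrative only, not part of the patch; it assumes an existing DataFrame `df` with columns `key` and `value`):

import org.apache.spark.sql.functions.col

// Hash-partition rows by "key" into 8 partitions, then sort rows within each partition.
val distributedAndSorted = df
  .distributeBy(col("key") :: Nil, 8)
  .localSort(col("value").desc)

// With no explicit count, the partition count falls back to spark.sql.shuffle.partitions.
val distributedDefault = df.distributeBy(col("key") :: Nil)

Taken together, distributeBy and localSort roughly correspond to Hive's DISTRIBUTE BY / SORT BY.
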
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 86d1d390f1..f4464e0b91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Strategy, execution}
+import org.apache.spark.sql.{Strategy, execution}
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SparkPlanner =>
@@ -455,8 +454,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
- case logical.RepartitionByExpression(expressions, child) =>
- execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+ case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+ execution.Exchange(HashPartitioning(
+ expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil
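
The planning rule above, together with the existing exchange-insertion logic, is what the tests below rely on: a child already hash-partitioned on the grouping keys satisfies the aggregation's required clustered distribution, so no extra Exchange is planned. A simplified, illustrative stand-in for that check, evaluated in a Scala REPL (not the planner's actual code, which compares Expression sets on the physical partitioning):

// Hash partitioning satisfies a required clustering when every partitioning expression
// appears in the clustering: rows that agree on the clustering keys then already share
// a partition.
def hashPartitioningSatisfies(partitionExprs: Set[String], requiredClustering: Set[String]): Boolean =
  partitionExprs.subsetOf(requiredClustering)

hashPartitioningSatisfies(Set("key"), Set("key"))                    // true  -> no Exchange between the aggregates
hashPartitioningSatisfies(Set("key", "value"), Set("key", "value"))  // true  -> no Exchange
hashPartitioningSatisfies(Set("key", "value"), Set("key"))           // false -> an Exchange is inserted
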
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c9d6e19d2c..6b86c5951b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -24,10 +24,14 @@ import scala.util.Random
import org.scalatest.Matchers._
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestData.TestData2
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
class DataFrameSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -997,4 +1001,116 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ /**
+ * Verifies that there is no Exchange between the Aggregations for `df`
+ */
+ private def verifyNonExchangingAgg(df: DataFrame) = {
+ var atFirstAgg: Boolean = false
+ df.queryExecution.executedPlan.foreach {
+ case agg: TungstenAggregate => {
+ atFirstAgg = !atFirstAgg
+ }
+ case _ => {
+ if (atFirstAgg) {
+ fail("Should not have operators between the two aggregations")
+ }
+ }
+ }
+ }
+
+ /**
+ * Verifies that there is an Exchange between the Aggregations for `df`
+ */
+ private def verifyExchangingAgg(df: DataFrame) = {
+ var atFirstAgg: Boolean = false
+ df.queryExecution.executedPlan.foreach {
+ case agg: TungstenAggregate => {
+ if (atFirstAgg) {
+ fail("Should not have back to back Aggregates")
+ }
+ atFirstAgg = true
+ }
+ case e: Exchange => atFirstAgg = false
+ case _ =>
+ }
+ }
+
+ test("distributeBy and localSort") {
+ val original = testData.repartition(1)
+ assert(original.rdd.partitions.length == 1)
+ val df = original.distributeBy(Column("key") :: Nil, 5)
+ assert(df.rdd.partitions.length == 5)
+ checkAnswer(original.select(), df.select())
+
+ val df2 = original.distributeBy(Column("key") :: Nil, 10)
+ assert(df2.rdd.partitions.length == 10)
+ checkAnswer(original.select(), df2.select())
+
+ // Group by the column we are distributed by. This should generate a plan with no exchange
+ // between the aggregates
+ val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count()
+ verifyNonExchangingAgg(df3)
+ verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+ .groupBy("key", "value").count())
+
+ // Grouping by just the first distributeBy expr, need to exchange.
+ verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+ .groupBy("key").count())
+
+ val data = sqlContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData2(i % 10, i))).toDF()
+
+ // Distribute and order by.
+ val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc)
+ // Walk each partition and verify that it is sorted descending and does not contain all
+ // the values.
+ df4.rdd.foreachPartition(p => {
+ var previousValue: Int = -1
+ var allSequential: Boolean = true
+ p.foreach(r => {
+ val v: Int = r.getInt(1)
+ if (previousValue != -1) {
+ if (previousValue < v) throw new SparkException("Partition is not ordered.")
+ if (v + 1 != previousValue) allSequential = false
+ }
+ previousValue = v
+ })
+ if (allSequential) throw new SparkException("Partition should not be globally ordered")
+ })
+
+ // Distribute and order by with multiple order bys
+ val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc)
+ // Walk each partition and verify that it is sorted ascending
+ df5.rdd.foreachPartition(p => {
+ var previousValue: Int = -1
+ var allSequential: Boolean = true
+ p.foreach(r => {
+ val v: Int = r.getInt(1)
+ if (previousValue != -1) {
+ if (previousValue > v) throw new SparkException("Partition is not ordered.")
+ if (v - 1 != previousValue) allSequential = false
+ }
+ previousValue = v
+ })
+ if (allSequential) throw new SparkException("Partition should not be all sequential")
+ })
+
+ // Distribute into one partition and order by. This partition should contain all the values.
+ val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc)
+ // Walk the single partition and verify that it is sorted ascending and contains all the values.
+ df6.rdd.foreachPartition(p => {
+ var previousValue: Int = -1
+ var allSequential: Boolean = true
+ p.foreach(r => {
+ val v: Int = r.getInt(1)
+ if (previousValue != -1) {
+ if (previousValue > v) throw new SparkException("Partition is not ordered.")
+ if (v - 1 != previousValue) allSequential = false
+ }
+ previousValue = v
+ })
+ if (!allSequential) throw new SparkException("Partition should contain all sequential values")
+ })
+ }
}