Diffstat (limited to 'sql')
4 files changed, 186 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 1f76b03bcb..a5bdee1b85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,10 +31,19 @@ case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
   extends RedistributeData
 
 /**
- * This method repartitions data using [[Expression]]s, and receives information about the
- * number of partitions during execution. Used when a specific ordering or distribution is
- * expected by the consumer of the query result. Use [[Repartition]] for RDD-like
+ * This method repartitions data using [[Expression]]s into `numPartitions`, and receives
+ * information about the number of partitions during execution. Used when a specific ordering or
+ * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
  * `coalesce` and `repartition`.
+ * If `numPartitions` is not specified, the number of partitions will be the number set by
+ * `spark.sql.shuffle.partitions`.
  */
-case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan)
-  extends RedistributeData
+case class RepartitionByExpression(
+    partitionExpressions: Seq[Expression],
+    child: LogicalPlan,
+    numPartitions: Option[Int] = None) extends RedistributeData {
+  numPartitions match {
+    case Some(n) => require(n > 0, "numPartitions must be greater than 0.")
+    case None => // Ok
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index aa817a037e..53ad3c0266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -241,6 +241,18 @@ class DataFrame private[sql](
     sb.toString()
   }
 
+  private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    Sort(sortOrder, global = global, logicalPlan)
+  }
+
   override def toString: String = {
     try {
       schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
@@ -633,15 +645,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def sort(sortExprs: Column*): DataFrame = {
-    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
-    }
-    Sort(sortOrder, global = true, logicalPlan)
+    sortInternal(true, sortExprs)
   }
 
   /**
@@ -663,6 +667,44 @@ class DataFrame private[sql](
   def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
 
   /**
+   * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into
+   * `numPartitions`. The resulting DataFrame is hash partitioned.
+   * @group dfops
+   * @since 1.6.0
+   */
+  def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = {
+    RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions))
+  }
+
+  /**
+   * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving
+   * the existing number of partitions. The resulting DataFrame is hash partitioned.
+   * @group dfops
+   * @since 1.6.0
+   */
+  def distributeBy(partitionExprs: Seq[Column]): DataFrame = {
+    RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] with each partition sorted by the given column names.
+   * @group dfops
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def localSort(sortCol: String, sortCols: String*): DataFrame = localSort((sortCol +: sortCols).map(apply) : _*)
+
+  /**
+   * Returns a new [[DataFrame]] with each partition sorted by the given expressions.
+   * @group dfops
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def localSort(sortExprs: Column*): DataFrame = {
+    sortInternal(false, sortExprs)
+  }
+
+  /**
    * Selects column based on the column name and return it as a [[Column]].
    * Note that the column name can also reference to a nested column like `a.b`.
    * @group dfops
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 86d1d390f1..f4464e0b91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Strategy, execution}
+import org.apache.spark.sql.{Strategy, execution}
 
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
@@ -455,8 +454,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case logical.RepartitionByExpression(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+        execution.Exchange(HashPartitioning(
+          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
      case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c9d6e19d2c..6b86c5951b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -24,10 +24,14 @@ import scala.util.Random
 
 import org.scalatest.Matchers._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestData.TestData2
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
 
 class DataFrameSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -997,4 +1001,116 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  /**
+   * Verifies that there is no Exchange between the Aggregations for `df`
+   */
+  private def verifyNonExchangingAgg(df: DataFrame) = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case agg: TungstenAggregate => {
+        atFirstAgg = !atFirstAgg
+      }
+      case _ => {
+        if (atFirstAgg) {
+          fail("Should not have operators between the two aggregations")
+        }
+      }
+    }
+  }
+
+  /**
+   * Verifies that there is an Exchange between the Aggregations for `df`
+   */
+  private def verifyExchangingAgg(df: DataFrame) = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case agg: TungstenAggregate => {
+        if (atFirstAgg) {
+          fail("Should not have back to back Aggregates")
+        }
+        atFirstAgg = true
+      }
+      case e: Exchange => atFirstAgg = false
+      case _ =>
+    }
+  }
+
+  test("distributeBy and localSort") {
+    val original = testData.repartition(1)
+    assert(original.rdd.partitions.length == 1)
+    val df = original.distributeBy(Column("key") :: Nil, 5)
+    assert(df.rdd.partitions.length == 5)
+    checkAnswer(original.select(), df.select())
+
+    val df2 = original.distributeBy(Column("key") :: Nil, 10)
+    assert(df2.rdd.partitions.length == 10)
+    checkAnswer(original.select(), df2.select())
+
+    // Group by the column we are distributed by. This should generate a plan with no exchange
+    // between the aggregates.
+    val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count()
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+      .groupBy("key", "value").count())
+
+    // Grouping by just the first distributeBy expr needs to exchange.
+    verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil)
+      .groupBy("key").count())
+
+    val data = sqlContext.sparkContext.parallelize(
+      (1 to 100).map(i => TestData2(i % 10, i))).toDF()
+
+    // Distribute and order by.
+    val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc)
+    // Walk each partition and verify that it is sorted descending and does not contain all
+    // the values.
+    df4.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue < v) throw new SparkException("Partition is not ordered.")
+          if (v + 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (allSequential) throw new SparkException("Partition should not be globally ordered")
+    })
+
+    // Distribute and order by with multiple order bys.
+    val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc)
+    // Walk each partition and verify that it is sorted ascending.
+    df5.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue > v) throw new SparkException("Partition is not ordered.")
+          if (v - 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (allSequential) throw new SparkException("Partition should not be all sequential")
+    })
+
+    // Distribute into one partition and order by. This partition should contain all the values.
+    val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc)
+    // Walk each partition and verify that it is sorted ascending and contains all the values.
+    df6.rdd.foreachPartition(p => {
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach(r => {
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue > v) throw new SparkException("Partition is not ordered.")
+          if (v - 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      })
+      if (!allSequential) throw new SparkException("Partition should contain all sequential values")
+    })
+  }
 }
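For reference, a minimal usage sketch of the API this patch adds. It only exercises the methods shown in the diff above (distributeBy and localSort); the names sqlContext and df are placeholders for an existing SQLContext and a DataFrame with "key" and "value" columns, not part of the change.

import org.apache.spark.sql.{Column, DataFrame, SQLContext}
import org.apache.spark.sql.functions._

object DistributeByExample {
  def run(sqlContext: SQLContext, df: DataFrame): Unit = {
    // Hash-partition by "key" into 5 partitions, then sort rows within each partition only
    // (no global sort, so no extra shuffle beyond the repartitioning exchange).
    val distributed = df.distributeBy(Column("key") :: Nil, 5)
      .localSort(col("value").desc)
    assert(distributed.rdd.partitions.length == 5)

    // Without an explicit partition count, the planner falls back to
    // spark.sql.shuffle.partitions (see RepartitionByExpression and SparkStrategies above).
    val defaultPartitions = df.distributeBy(Column("key") :: Nil)

    // Grouping by the same expressions used in distributeBy can avoid a second Exchange
    // between the aggregates, as exercised by verifyNonExchangingAgg in DataFrameSuite.
    val counts = defaultPartitions.groupBy("key").count()
    counts.show()
  }
}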