From 84acd08e0886aa23195f35837c15c09aa7804aff Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Feb 2015 19:53:57 -0800 Subject: [SPARK-5602][SQL] Better support for creating DataFrame from local data collection 1. Added methods to create DataFrames from Seq[Product] 2. Added executeTake to avoid running a Spark job on LocalRelations. Author: Reynold Xin Closes #4372 from rxin/localDataFrame and squashes the following commits: f696858 [Reynold Xin] style checker. 839ef7f [Reynold Xin] [SPARK-5602][SQL] Better support for creating DataFrame from local data collection. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/plans/logical/LocalRelation.scala | 57 ++++++++++++++++++++++ .../sql/catalyst/plans/logical/TestRelation.scala | 54 -------------------- .../spark/sql/types/DataTypeConversions.scala | 16 ++++++ .../scala/org/apache/spark/sql/SQLContext.scala | 41 ++++++++++++++-- .../apache/spark/sql/execution/ExistingRDD.scala | 36 +++++--------- .../spark/sql/execution/LocalTableScan.scala | 36 ++++++++++++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 49 ++++++++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 9 +--- .../spark/sql/execution/basicOperators.scala | 44 +---------------- .../org/apache/spark/sql/execution/commands.scala | 2 + 11 files changed, 214 insertions(+), 132 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8e79e532ca..0445f3aa07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -211,7 +211,7 @@ trait ScalaReflection { */ def asRelation: LocalRelation = { val output = attributesFor[A] - LocalRelation(output, data) + LocalRelation.fromProduct(output, data) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala new file mode 100644 index 0000000000..92bd057c6f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField} + +object LocalRelation { + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = { + new LocalRelation(StructType(output1 +: output).toAttributes) + } + + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { + val schema = StructType.fromAttributes(output) + LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema))) + } +} + +case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil) + extends LeafNode with analysis.MultiInstanceRelation { + + /** + * Returns an identical copy of this relation with new exprIds for all attributes. Different + * attributes are required when a relation is going to be included multiple times in the same + * query. + */ + override final def newInstance(): this.type = { + LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] + } + + override protected def stringArgs = Iterator(output) + + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LocalRelation(otherOutput, otherData) => + otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala deleted file mode 100644 index d90af45b37..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical - -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructType, StructField} - -object LocalRelation { - def apply(output: Attribute*): LocalRelation = new LocalRelation(output) - - def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( - StructType(output1 +: output).toAttributes - ) -} - -case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) - extends LeafNode with analysis.MultiInstanceRelation { - - // TODO: Validate schema compliance. - def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData) - - /** - * Returns an identical copy of this relation with new exprIds for all attributes. Different - * attributes are required when a relation is going to be included multiple times in the same - * query. - */ - override final def newInstance: this.type = { - LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type] - } - - override protected def stringArgs = Iterator(output) - - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala index 21f478c80c..c243be07a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -19,11 +19,27 @@ package org.apache.spark.sql.types import java.text.SimpleDateFormat +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow protected[sql] object DataTypeConversions { + def productToRow(product: Product, schema: StructType): Row = { + val mutableRow = new GenericMutableRow(product.productArity) + val schemaFields = schema.fields.toArray + + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = + ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType) + i += 1 + } + + mutableRow + } + def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { // JDBC escape string diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1661282fc3..5ab5494f80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.json._ @@ -163,17 +163,52 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Removes the specified table from the in-memory cache. */ def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + // scalastyle:off + // Disable style checker so "implicits" object can start with lowercase i + /** + * Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s. + */ + object implicits { + // scalastyle:on + /** + * Creates a DataFrame from an RDD of case classes. + * + * @group userf + */ + implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { + self.createDataFrame(rdd) + } + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + self.createDataFrame(data) + } + } + /** * Creates a DataFrame from an RDD of case classes. * * @group userf */ - implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { + // TODO: Remove implicit here. + implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) + DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + */ + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + SparkPlan.currentContext.set(self) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 20b14834bb..248dc1512b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -54,12 +54,13 @@ object RDDConversions { } } +/** Logical plan node for scanning data from an RDD. */ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { - def children = Nil + override def children = Nil - def newInstance() = + override def newInstance() = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] override def sameResult(plan: LogicalPlan) = plan match { @@ -74,39 +75,28 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont ) } +/** Physical plan node for scanning data from an RDD. */ case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } -@deprecated("Use LogicalRDD", "1.2.0") -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - -@deprecated("Use LogicalRDD", "1.2.0") -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { +/** Logical plan node for scanning data from a local collection. */ +case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { - def output = alreadyPlanned.output override def children = Nil - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } + override def newInstance() = + LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] override def sameResult(plan: LogicalPlan) = plan match { - case SparkLogicalPlan(ExistingRdd(_, rdd)) => - rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + case LogicalRDD(_, otherRDD) => rows == rows case _ => false } @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) + // TODO: Improve the statistics estimation. + // This is made small enough so it can be broadcasted. + sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala new file mode 100644 index 0000000000..d6d8258f46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.Attribute + + +/** + * Physical plan node for scanning data from a local collection. + */ +case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { + + private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + + override def execute() = rdd + + override def executeCollect() = rows.toArray + + override def executeTake(limit: Int) = rows.take(limit).toArray +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6fecd1ff06..052766c20a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import scala.collection.mutable.ArrayBuffer + object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() } @@ -77,8 +79,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = + def executeCollect(): Array[Row] = { execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + } + + /** + * Runs this query returning the first `n` rows as an array. + * + * This is modeled after RDD.take but never runs any job locally on the driver. + */ + def executeTake(n: Int): Array[Row] = { + if (n == 0) { + return new Array[Row](0) + } + + val childRDD = execute().map(_.copy()) + + val buf = new ArrayBuffer[Row] + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (buf.size < n && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (buf.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = n - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val sc = sqlContext.sparkContext + val res = + sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + + res.foreach(buf ++= _.take(n - buf.size)) + partsScanned += numPartsToTry + } + + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + } protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ff0609d4b3..0c77d399b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.parquet._ @@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil - case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - val nPartitions = if (data.isEmpty) 1 else numPartitions - PhysicalRDD( - output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), - StructType.fromAttributes(output))) :: Nil + LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 16ca4be558..66aed5d511 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -103,49 +103,7 @@ case class Limit(limit: Int, child: SparkPlan) override def output = child.output override def outputPartitioning = SinglePartition - /** - * A custom implementation modeled after the take function on RDDs but which never runs any job - * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few - * rows. - */ - override def executeCollect(): Array[Row] = { - if (limit == 0) { - return new Array[Row](0) - } - - val childRDD = child.execute().map(_.copy()) - - val buf = new ArrayBuffer[Row] - val totalParts = childRDD.partitions.length - var partsScanned = 0 - while (buf.size < limit && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 - if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. - if (buf.size == 0) { - numPartsToTry = totalParts - 1 - } else { - numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt - } - } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions - - val left = limit - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) - - res.foreach(buf ++= _.take(limit - buf.size)) - partsScanned += numPartsToTry - } - - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) - } + override def executeCollect(): Array[Row] = child.executeTake(limit) override def execute() = { val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e1c9a2be7d..1bc53968c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -58,6 +58,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } -- cgit v1.2.3