author     Reynold Xin <rxin@databricks.com>  2015-02-04 19:53:57 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-02-04 19:53:57 -0800
commit     84acd08e0886aa23195f35837c15c09aa7804aff (patch)
tree       85ccf925a6a6123463afdfb15e9913c953704ea5 /sql/core
parent     206f9bc3622348926d73e43c8010519f7df9b34f (diff)
[SPARK-5602][SQL] Better support for creating DataFrame from local data collection
1. Added methods to create DataFrames from Seq[Product].
2. Added executeTake to avoid running a Spark job on LocalRelations.

Author: Reynold Xin <rxin@databricks.com>

Closes #4372 from rxin/localDataFrame and squashes the following commits:

f696858 [Reynold Xin] style checker.
839ef7f [Reynold Xin] [SPARK-5602][SQL] Better support for creating DataFrame from local data collection.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                | 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala     | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala  | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala       | 49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala        |  2
7 files changed, 140 insertions, 77 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1661282fc3..5ab5494f80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.json._
@@ -163,17 +163,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+ // scalastyle:off
+ // Disable style checker so "implicits" object can start with lowercase i
+ /**
+ * Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
+ */
+ object implicits {
+ // scalastyle:on
+ /**
+ * Creates a DataFrame from an RDD of case classes.
+ *
+ * @group userf
+ */
+ implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+ self.createDataFrame(rdd)
+ }
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ */
+ implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ self.createDataFrame(data)
+ }
+ }
+
/**
* Creates a DataFrame from an RDD of case classes.
*
* @group userf
*/
- implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
+ // TODO: Remove implicit here.
+ implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
- DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
+ DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self))
+ }
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ */
+ def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ SparkPlan.currentContext.set(self)
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributeSeq = schema.toAttributes
+ DataFrame(self, LocalRelation.fromProduct(attributeSeq, data))
}
/**
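A minimal usage sketch of the two new entry points added above. The Person case class, the sc SparkContext, and the sqlCtx value are assumptions for illustration, not part of this patch:

case class Person(name: String, age: Int)

val sqlCtx = new org.apache.spark.sql.SQLContext(sc)
import sqlCtx.implicits._

// Local Seq of case classes -> DataFrame via the new implicit conversion;
// the data stays in a LocalRelation rather than being parallelized into an RDD.
val df: org.apache.spark.sql.DataFrame = Seq(Person("Alice", 30), Person("Bob", 25))

// The same thing, calling the new method explicitly:
val df2 = sqlCtx.createDataFrame(Seq(Person("Carol", 41)))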
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 20b14834bb..248dc1512b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -54,12 +54,13 @@ object RDDConversions {
}
}
+/** Logical plan node for scanning data from an RDD. */
case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
- def children = Nil
+ override def children = Nil
- def newInstance() =
+ override def newInstance() =
LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
override def sameResult(plan: LogicalPlan) = plan match {
@@ -74,39 +75,28 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont
)
}
+/** Physical plan node for scanning data from an RDD. */
case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
-@deprecated("Use LogicalRDD", "1.2.0")
-case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- override def execute() = rdd
-}
-
-@deprecated("Use LogicalRDD", "1.2.0")
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
- extends LogicalPlan with MultiInstanceRelation {
+/** Logical plan node for scanning data from a local collection. */
+case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
- def output = alreadyPlanned.output
override def children = Nil
- override final def newInstance(): this.type = {
- SparkLogicalPlan(
- alreadyPlanned match {
- case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
- case _ => sys.error("Multiple instance of the same relation detected.")
- })(sqlContext).asInstanceOf[this.type]
- }
+ override def newInstance() =
+ LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type]
override def sameResult(plan: LogicalPlan) = plan match {
- case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
- rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
+ case LogicalLocalTable(_, otherRows) => rows == otherRows
case _ => false
}
@transient override lazy val statistics = Statistics(
- // TODO: Instead of returning a default value here, find a way to return a meaningful size
- // estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
+ // TODO: Improve the statistics estimation.
+ // This is made small enough so it can be broadcasted.
+ sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1
)
}
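The statistics override above is what makes a small local table an attractive join input: its estimated size sits just under spark.sql.autoBroadcastJoinThreshold, so the planner may broadcast it instead of shuffling both sides. A hedged sketch; Dept, the events table, and the column names are assumptions for illustration, and the chosen join strategy still depends on the planner and the threshold setting:

case class Dept(id: Int, name: String)

// Reuses the hypothetical sqlCtx from the earlier sketch.
val depts = sqlCtx.createDataFrame(Seq(Dept(1, "eng"), Dept(2, "ops")))
depts.registerTempTable("depts")

// `events` is assumed to be a large, RDD-backed table registered elsewhere.
// Because the local side's estimated size is below the broadcast threshold,
// the planner may broadcast `depts` rather than shuffling it.
val joined = sqlCtx.sql(
  "SELECT e.id, d.name FROM events e JOIN depts d ON e.deptId = d.id")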
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
new file mode 100644
index 0000000000..d6d8258f46
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+
+/**
+ * Physical plan node for scanning data from a local collection.
+ */
+case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode {
+
+ private lazy val rdd = sqlContext.sparkContext.parallelize(rows)
+
+ override def execute() = rdd
+
+ override def executeCollect() = rows.toArray
+
+ override def executeTake(limit: Int) = rows.take(limit).toArray
+}
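The point of overriding executeCollect and executeTake here is that materializing a purely local DataFrame never schedules tasks: the rows come straight from the driver-side Seq, and the lazily parallelized RDD is only built for operators that genuinely need one. A small sketch, reusing the hypothetical df from the first example:

// Served directly from the in-memory rows via LocalTableScan.executeCollect;
// no Spark job is launched for this call.
val everyone = df.collect()

// Only operations that require distributed execution (e.g. a join against
// RDD-backed data, or an explicit repartition) fall back to execute().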
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 6fecd1ff06..052766c20a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import scala.collection.mutable.ArrayBuffer
+
object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
}
@@ -77,8 +79,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] =
+ def executeCollect(): Array[Row] = {
execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
+ }
+
+ /**
+ * Runs this query returning the first `n` rows as an array.
+ *
+ * This is modeled after RDD.take but never runs any job locally on the driver.
+ */
+ def executeTake(n: Int): Array[Row] = {
+ if (n == 0) {
+ return new Array[Row](0)
+ }
+
+ val childRDD = execute().map(_.copy())
+
+ val buf = new ArrayBuffer[Row]
+ val totalParts = childRDD.partitions.length
+ var partsScanned = 0
+ while (buf.size < n && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = n - buf.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val sc = sqlContext.sparkContext
+ val res =
+ sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+ res.foreach(buf ++= _.take(n - buf.size))
+ partsScanned += numPartsToTry
+ }
+
+ buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
+ }
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
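To make the interpolation in executeTake concrete: if executeTake(100) scans 2 partitions on its first pass and has only 10 rows so far, the next pass asks for roughly 1.5 * 100 * 2 / 10 = 30 partitions; if it has found nothing at all, it tries every remaining partition at once. A standalone sketch of just the growth rule, with hypothetical names:

// Mirrors the numPartsToTry logic above, isolated for illustration.
def nextPartsToTry(n: Int, partsScanned: Int, rowsSoFar: Int, totalParts: Int): Int = {
  if (partsScanned == 0) 1                        // first pass: probe a single partition
  else if (rowsSoFar == 0) totalParts - 1         // found nothing yet: try everything else
  else math.max(0, (1.5 * n * partsScanned / rowsSoFar).toInt)
}

nextPartsToTry(100, partsScanned = 2, rowsSoFar = 10, totalParts = 64)  // 30
nextPartsToTry(100, partsScanned = 2, rowsSoFar = 0,  totalParts = 64)  // 63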
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ff0609d4b3..0c77d399b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.parquet._
@@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
- case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
case logical.LocalRelation(output, data) =>
- val nPartitions = if (data.isEmpty) 1 else numPartitions
- PhysicalRDD(
- output,
- RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
- StructType.fromAttributes(output))) :: Nil
+ LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
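With the strategy change above, a DataFrame built from a local collection plans directly into the LocalTableScan operator instead of being parallelized into a PhysicalRDD at planning time. A minimal sketch of the shape of this planning rule, written as a free-standing function rather than the actual BasicOperators strategy:

import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.{LocalTableScan, SparkPlan}

// Hypothetical sketch: a logical LocalRelation node maps straight to the
// driver-side LocalTableScan physical node; nothing is parallelized to plan it.
def planLocalScan(plan: LogicalPlan): Seq[SparkPlan] = plan match {
  case LocalRelation(output, data) => LocalTableScan(output, data) :: Nil
  case _ => Nil
}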
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 16ca4be558..66aed5d511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -103,49 +103,7 @@ case class Limit(limit: Int, child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition
- /**
- * A custom implementation modeled after the take function on RDDs but which never runs any job
- * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
- * rows.
- */
- override def executeCollect(): Array[Row] = {
- if (limit == 0) {
- return new Array[Row](0)
- }
-
- val childRDD = child.execute().map(_.copy())
-
- val buf = new ArrayBuffer[Row]
- val totalParts = childRDD.partitions.length
- var partsScanned = 0
- while (buf.size < limit && partsScanned < totalParts) {
- // The number of partitions to try in this iteration. It is ok for this number to be
- // greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
- if (partsScanned > 0) {
- // If we didn't find any rows after the first iteration, just try all partitions next.
- // Otherwise, interpolate the number of partitions we need to try, but overestimate it
- // by 50%.
- if (buf.size == 0) {
- numPartsToTry = totalParts - 1
- } else {
- numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
- }
- }
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
-
- val left = limit - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
- val sc = sqlContext.sparkContext
- val res =
- sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
-
- res.foreach(buf ++= _.take(limit - buf.size))
- partsScanned += numPartsToTry
- }
-
- buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
- }
+ override def executeCollect(): Array[Row] = child.executeTake(limit)
override def execute() = {
val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
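With Limit.executeCollect delegating to the shared executeTake, a limit-then-collect on a distributed plan still scans partitions incrementally, while on a plan backed by a local collection it is answered without launching a job at all. A brief sketch; df is the hypothetical local DataFrame from the earlier example and bigDF is assumed to be a large RDD-backed DataFrame:

// Local plan: answered from LocalTableScan.executeTake, no job scheduled.
val firstLocal = df.limit(1).collect()

// Distributed plan: executeTake probes partitions incrementally,
// growing the number of partitions per pass as described above.
val firstFive = bigDF.limit(5).collect()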
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index e1c9a2be7d..1bc53968c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -58,6 +58,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
override def executeCollect(): Array[Row] = sideEffectResult.toArray
+ override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
+
override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
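ExecutedCommand gets the same treatment: both collect and take are served from the command's precomputed sideEffectResult on the driver, so a configuration command never launches a job just to show its output. A small hedged sketch, assuming SET statements are routed through the command path as they were at the time:

// The result rows of the SET command come from sideEffectResult;
// limiting and collecting them does not schedule any tasks.
val settings = sqlCtx.sql("SET spark.sql.shuffle.partitions=8").limit(1).collect()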