author     Reynold Xin <rxin@databricks.com>   2015-02-04 19:53:57 -0800
committer  Reynold Xin <rxin@databricks.com>   2015-02-04 19:53:57 -0800
commit     84acd08e0886aa23195f35837c15c09aa7804aff (patch)
tree       85ccf925a6a6123463afdfb15e9913c953704ea5 /sql
parent     206f9bc3622348926d73e43c8010519f7df9b34f (diff)
[SPARK-5602][SQL] Better support for creating DataFrame from local data collection
1. Added methods to create DataFrames from Seq[Product]
2. Added executeTake to avoid running a Spark job on LocalRelations.

Author: Reynold Xin <rxin@databricks.com>

Closes #4372 from rxin/localDataFrame and squashes the following commits:

f696858 [Reynold Xin] style checker.
839ef7f [Reynold Xin] [SPARK-5602][SQL] Better support for creating DataFrame from local data collection.
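As a sketch of how the new local-collection support might be used from application code (the SparkContext setup, the Person case class, and the limit/collect calls are illustrative assumptions in the DataFrame API of the time, not part of this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Hypothetical case class used only for illustration.
    case class Person(name: String, age: Int)

    object LocalDataFrameExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("local-df").setMaster("local[*]"))
        val sqlContext = new SQLContext(sc)

        // The new createDataFrame(Seq[Product]) overload builds a LocalRelation directly,
        // so no RDD is constructed just to hold data that already lives on the driver.
        val df = sqlContext.createDataFrame(Seq(Person("alice", 29), Person("bob", 31)))

        // With executeTake in place, Limit over local data no longer has to launch a job.
        df.limit(1).collect().foreach(println)

        sc.stop()
      }
    }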
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala) | 23
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala |  2
10 files changed, 170 insertions(+), 88 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8e79e532ca..0445f3aa07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -211,7 +211,7 @@ trait ScalaReflection {
*/
def asRelation: LocalRelation = {
val output = attributesFor[A]
- LocalRelation(output, data)
+ LocalRelation.fromProduct(output, data)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index d90af45b37..92bd057c6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -17,31 +17,34 @@
package org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructType, StructField}
+import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField}
object LocalRelation {
def apply(output: Attribute*): LocalRelation = new LocalRelation(output)
- def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
- StructType(output1 +: output).toAttributes
- )
+ def apply(output1: StructField, output: StructField*): LocalRelation = {
+ new LocalRelation(StructType(output1 +: output).toAttributes)
+ }
+
+ def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
+ val schema = StructType.fromAttributes(output)
+ LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema)))
+ }
}
-case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
+case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
extends LeafNode with analysis.MultiInstanceRelation {
- // TODO: Validate schema compliance.
- def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData)
-
/**
* Returns an identical copy of this relation with new exprIds for all attributes. Different
* attributes are required when a relation is going to be included multiple times in the same
* query.
*/
- override final def newInstance: this.type = {
- LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type]
+ override final def newInstance(): this.type = {
+ LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
}
override protected def stringArgs = Iterator(output)
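The switch from Seq[Product] to Seq[Row] means the case-class data is flattened into positional rows eagerly, on the driver, when the relation is built. A minimal self-contained sketch of that flattening, using a plain Seq[Any] in place of Catalyst's Row and a made-up Record class:

    // Illustration only: mirrors the idea behind LocalRelation.fromProduct /
    // DataTypeConversions.productToRow without depending on Catalyst internals.
    case class Record(id: Int, name: String)

    object ProductToRowSketch {
      // Flatten any Product (e.g. a case class instance) into a positional sequence of values.
      def toRow(p: Product): Seq[Any] = (0 until p.productArity).map(p.productElement)

      def main(args: Array[String]): Unit = {
        val rows = Seq(Record(1, "a"), Record(2, "b")).map(toRow)
        rows.foreach(println)   // Vector(1, a) and Vector(2, b)
      }
    }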
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
index 21f478c80c..c243be07a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
@@ -19,11 +19,27 @@ package org.apache.spark.sql.types
import java.text.SimpleDateFormat
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
protected[sql] object DataTypeConversions {
+ def productToRow(product: Product, schema: StructType): Row = {
+ val mutableRow = new GenericMutableRow(product.productArity)
+ val schemaFields = schema.fields.toArray
+
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) =
+ ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType)
+ i += 1
+ }
+
+ mutableRow
+ }
+
def stringToTime(s: String): java.util.Date = {
if (!s.contains('T')) {
// JDBC escape string
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1661282fc3..5ab5494f80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.json._
@@ -163,17 +163,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+ // scalastyle:off
+ // Disable style checker so "implicits" object can start with lowercase i
+ /**
+ * Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
+ */
+ object implicits {
+ // scalastyle:on
+ /**
+ * Creates a DataFrame from an RDD of case classes.
+ *
+ * @group userf
+ */
+ implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+ self.createDataFrame(rdd)
+ }
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ */
+ implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ self.createDataFrame(data)
+ }
+ }
+
/**
* Creates a DataFrame from an RDD of case classes.
*
* @group userf
*/
- implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
+ // TODO: Remove implicit here.
+ implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
- DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
+ DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self))
+ }
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ */
+ def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ SparkPlan.currentContext.set(self)
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributeSeq = schema.toAttributes
+ DataFrame(self, LocalRelation.fromProduct(attributeSeq, data))
}
/**
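A sketch of how the new implicits object might be used, assuming a local SparkContext and an illustrative Person case class (neither is part of this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{DataFrame, SQLContext}

    // Hypothetical case class, defined at the top level so TypeTag-based schema inference works.
    case class Person(name: String, age: Int)

    object ImplicitLocalDataFrame {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("implicit-df").setMaster("local[*]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._   // brings the Seq[Product] => DataFrame conversion into scope

        // The implicit createDataFrame(Seq[A]) fires here; the data stays in a LocalRelation.
        val people: DataFrame = Seq(Person("alice", 29), Person("bob", 31))
        println(people.schema)

        sc.stop()
      }
    }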
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 20b14834bb..248dc1512b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -54,12 +54,13 @@ object RDDConversions {
}
}
+/** Logical plan node for scanning data from an RDD. */
case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
- def children = Nil
+ override def children = Nil
- def newInstance() =
+ override def newInstance() =
LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
override def sameResult(plan: LogicalPlan) = plan match {
@@ -74,39 +75,28 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont
)
}
+/** Physical plan node for scanning data from an RDD. */
case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
-@deprecated("Use LogicalRDD", "1.2.0")
-case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- override def execute() = rdd
-}
-
-@deprecated("Use LogicalRDD", "1.2.0")
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
- extends LogicalPlan with MultiInstanceRelation {
+/** Logical plan node for scanning data from a local collection. */
+case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
- def output = alreadyPlanned.output
override def children = Nil
- override final def newInstance(): this.type = {
- SparkLogicalPlan(
- alreadyPlanned match {
- case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
- case _ => sys.error("Multiple instance of the same relation detected.")
- })(sqlContext).asInstanceOf[this.type]
- }
+ override def newInstance() =
+ LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type]
override def sameResult(plan: LogicalPlan) = plan match {
- case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
- rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
+ case LogicalRDD(_, otherRDD) => rows == rows
case _ => false
}
@transient override lazy val statistics = Statistics(
- // TODO: Instead of returning a default value here, find a way to return a meaningful size
- // estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
+ // TODO: Improve the statistics estimation.
+ // This is made small enough so it can be broadcasted.
+ sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1
)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
new file mode 100644
index 0000000000..d6d8258f46
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+
+/**
+ * Physical plan node for scanning data from a local collection.
+ */
+case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode {
+
+ private lazy val rdd = sqlContext.sparkContext.parallelize(rows)
+
+ override def execute() = rdd
+
+ override def executeCollect() = rows.toArray
+
+ override def executeTake(limit: Int) = rows.take(limit).toArray
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 6fecd1ff06..052766c20a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import scala.collection.mutable.ArrayBuffer
+
object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
}
@@ -77,8 +79,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] =
+ def executeCollect(): Array[Row] = {
execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
+ }
+
+ /**
+ * Runs this query returning the first `n` rows as an array.
+ *
+ * This is modeled after RDD.take but never runs any job locally on the driver.
+ */
+ def executeTake(n: Int): Array[Row] = {
+ if (n == 0) {
+ return new Array[Row](0)
+ }
+
+ val childRDD = execute().map(_.copy())
+
+ val buf = new ArrayBuffer[Row]
+ val totalParts = childRDD.partitions.length
+ var partsScanned = 0
+ while (buf.size < n && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = n - buf.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val sc = sqlContext.sparkContext
+ val res =
+ sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+ res.foreach(buf ++= _.take(n - buf.size))
+ partsScanned += numPartsToTry
+ }
+
+ buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
+ }
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
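The escalation logic in executeTake can be illustrated with a self-contained sketch that runs against plain Scala collections: each inner Seq stands in for an RDD partition, and scanning one stands in for running a job on it. This is a simplified variant for illustration, not Spark's internal code:

    import scala.collection.mutable.ArrayBuffer

    object IncrementalTakeSketch {
      /** Take up to n elements, scanning "partitions" in escalating batches, like executeTake. */
      def take[A](partitions: Seq[Seq[A]], n: Int): Seq[A] = {
        if (n == 0) return Seq.empty

        val buf = new ArrayBuffer[A]
        val totalParts = partitions.length
        var partsScanned = 0
        while (buf.size < n && partsScanned < totalParts) {
          // First round: scan a single partition. Later rounds: if nothing was found yet,
          // scan everything that is left; otherwise extrapolate from the hit rate so far
          // and overestimate by 50%.
          val numPartsToTry =
            if (partsScanned == 0) 1
            else if (buf.isEmpty) totalParts - partsScanned
            else math.max(1, (1.5 * n * partsScanned / buf.size).toInt)

          val range = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
          range.foreach(i => buf ++= partitions(i).take(n - buf.size))
          partsScanned += numPartsToTry
        }
        buf.take(n).toSeq
      }

      def main(args: Array[String]): Unit = {
        val parts = Seq(Seq(1, 2), Seq.empty[Int], Seq(3, 4, 5), Seq(6))
        println(take(parts, 4).mkString(", "))   // prints: 1, 2, 3, 4
      }
    }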
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ff0609d4b3..0c77d399b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.parquet._
@@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
- case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
case logical.LocalRelation(output, data) =>
- val nPartitions = if (data.isEmpty) 1 else numPartitions
- PhysicalRDD(
- output,
- RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
- StructType.fromAttributes(output))) :: Nil
+ LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 16ca4be558..66aed5d511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -103,49 +103,7 @@ case class Limit(limit: Int, child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition
- /**
- * A custom implementation modeled after the take function on RDDs but which never runs any job
- * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
- * rows.
- */
- override def executeCollect(): Array[Row] = {
- if (limit == 0) {
- return new Array[Row](0)
- }
-
- val childRDD = child.execute().map(_.copy())
-
- val buf = new ArrayBuffer[Row]
- val totalParts = childRDD.partitions.length
- var partsScanned = 0
- while (buf.size < limit && partsScanned < totalParts) {
- // The number of partitions to try in this iteration. It is ok for this number to be
- // greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
- if (partsScanned > 0) {
- // If we didn't find any rows after the first iteration, just try all partitions next.
- // Otherwise, interpolate the number of partitions we need to try, but overestimate it
- // by 50%.
- if (buf.size == 0) {
- numPartsToTry = totalParts - 1
- } else {
- numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
- }
- }
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
-
- val left = limit - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
- val sc = sqlContext.sparkContext
- val res =
- sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
-
- res.foreach(buf ++= _.take(limit - buf.size))
- partsScanned += numPartsToTry
- }
-
- buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
- }
+ override def executeCollect(): Array[Row] = child.executeTake(limit)
override def execute() = {
val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index e1c9a2be7d..1bc53968c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -58,6 +58,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
override def executeCollect(): Array[Row] = sideEffectResult.toArray
+ override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
+
override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}