author     Reynold Xin <rxin@apache.org>   2014-06-20 22:49:48 -0700
committer  Reynold Xin <rxin@apache.org>   2014-06-20 22:49:48 -0700
commit     ca5d8b5904dc6dd5b691af506d3a842e508b3673 (patch)
tree       466037fe69788109de940b13b97e6fcecdb567f0 /sql
parent     648553d48ee1f830406750b50ec4cc322bcf47fe (diff)
[SQL] Pass SQLContext instead of SparkContext into physical operators.
This makes it easier to use config options in operators.

Author: Reynold Xin <rxin@apache.org>

Closes #1164 from rxin/sqlcontext and squashes the following commits:

797b2fd [Reynold Xin] Pass SQLContext instead of SparkContext into physical operators.
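To illustrate the pattern this commit applies, here is a minimal sketch (not part of the patch, written against the 1.0-era Spark SQL API) of a physical operator that takes the SQLContext as a second, @transient constructor argument, returns it from otherCopyArgs so Catalyst tree copies preserve it, and reaches the SparkContext and any SQL configuration through it. ExampleLimit is a hypothetical operator name used only for illustration.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}

// Hypothetical operator: keeps the first `limit` rows in a single partition.
case class ExampleLimit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
  extends UnaryNode {

  // The curried argument is not part of productIterator, so it has to be
  // supplied here for the copies Catalyst makes during tree transformations.
  override def otherCopyArgs = sqlContext :: Nil

  override def output = child.output

  // The SparkContext and any config options are both reachable via the SQLContext.
  override def execute(): RDD[Row] =
    sqlContext.sparkContext.makeRDD(child.execute().take(limit).map(_.copy()), 1)
}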
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 2
7 files changed, 51 insertions, 44 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ab376e5504..c60af28b2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
protected[sql] class SparkPlanner extends SparkStrategies {
- val sparkContext = self.sparkContext
+ val sparkContext: SparkContext = self.sparkContext
+
+ val sqlContext: SQLContext = self
def numPartitions = self.numShufflePartitions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 34d88fe4bd..d85d2d7844 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.SQLContext
/**
* :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)(@transient sc: SparkContext)
+ child: SparkPlan)(@transient sqlContext: SQLContext)
extends UnaryNode with NoBind {
override def requiredChildDistribution =
@@ -55,7 +56,7 @@ case class Aggregate(
}
}
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
// HACK: Generators don't correctly preserve their output through serializations so we grab
// out child's output attributes statically here.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4694f25d6d..bd8ae4cdde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition)(sparkContext) :: Nil
+ planLater(left), planLater(right), condition)(sqlContext) :: Nil
case _ => Nil
}
}
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
- planLater(child))(sparkContext))(sparkContext) :: Nil
+ planLater(child))(sqlContext))(sqlContext) :: Nil
} else {
Nil
}
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+ planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
case _ => Nil
}
}
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
+ execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
case _ => Nil
}
}
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val relation =
ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
- InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
+ InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
- InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
+ InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
- ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+ ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
case _ => Nil
}
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
execution.Aggregate(
- partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+ partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
case logical.Sort(sortExprs, child) =>
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+ execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
execution.ExistingRdd(output, dataAsRdd) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child))(sparkContext) :: Nil
+ execution.Limit(limit, planLater(child))(sqlContext) :: Nil
case Unions(unionChildren) =>
- execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+ execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
case logical.NoRelation =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 8969794c69..18f4a5877b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output = children.head.output
- override def execute() = sc.union(children.map(_.execute()))
+ override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
}
/**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
* data to a single partition to compute the global limit.
*/
@DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+ extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
override def output = child.output
@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
*/
@DeveloperApi
case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sc: SparkContext) extends UnaryNode {
- override def otherCopyArgs = sc :: Nil
+ (@transient sqlContext: SQLContext) extends UnaryNode {
+ override def otherCopyArgs = sqlContext :: Nil
override def output = child.output
@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 8d7a5ba59f..84bdde38b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution
import scala.collection.mutable.{ArrayBuffer, BitSet}
-import org.apache.spark.SparkContext
-
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- (@transient sc: SparkContext)
+ (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
def output = left.output
@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(
def execute() = {
- val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ val broadcastedRelation =
+ sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
@DeveloperApi
case class BroadcastNestedLoopJoin(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
- (@transient sc: SparkContext)
+ (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
def output = left.output ++ right.output
@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(
def execute() = {
- val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ val broadcastedRelation =
+ sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
}
// TODO: Breaks lineage.
- sc.union(
- streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+ sqlContext.sparkContext.union(
+ streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 624f2e2fa1..ade823b51c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType
-import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
-import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
/**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
output: Seq[Attribute],
relation: ParquetRelation,
columnPruningPred: Seq[Expression])(
- @transient val sc: SparkContext)
+ @transient val sqlContext: SQLContext)
extends LeafNode {
override def execute(): RDD[Row] = {
+ val sc = sqlContext.sparkContext
val job = new Job(sc.hadoopConfiguration)
ParquetInputFormat.setReadSupportClass(
job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
.filter(_ != null) // Parquet's record filters may produce null values
}
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
/**
* Applies a (candidate) projection.
@@ -104,7 +105,7 @@ case class ParquetTableScan(
def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
val success = validateProjection(prunedAttributes)
if (success) {
- ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+ ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
} else {
sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
overwrite: Boolean = false)(
- @transient val sc: SparkContext)
+ @transient val sqlContext: SQLContext)
extends UnaryNode with SparkHadoopMapReduceUtil {
/**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
val childRdd = child.execute()
assert(childRdd != null)
- val job = new Job(sc.hadoopConfiguration)
+ val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(
override def output = child.output
- override def otherCopyArgs = sc :: Nil
+ override def otherCopyArgs = sqlContext :: Nil
/**
* Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
val wrappedConf = new SerializableWritable(job.getConfiguration)
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
- val stageId = sc.newRddId()
+ val stageId = sqlContext.sparkContext.newRddId()
val taskIdOffset =
if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
- sc.runJob(rdd, writeShard _)
+ sqlContext.sparkContext.runJob(rdd, writeShard _)
jobCommitter.commitJob(jobTaskContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 7714eb1b56..2ca0c1cdcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- Seq())(TestSQLContext.sparkContext)
+ Seq())(TestSQLContext)
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))