diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-20 00:44:02 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-20 00:44:02 -0700 |
commit | 7abe9a6578a99f4df50040d5cfe083c389c7b97f (patch) | |
tree | 59f4ae404f2a0de4fae55da0849103999f87446b /sql/core/src/main | |
parent | 14869ae64eb27830179d4954a5dc3e0a1e1330b4 (diff) | |
download | spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.tar.gz spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.tar.bz2 spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.zip |
[SPARK-9013][SQL] generate MutableProjection directly instead of return a function
`MutableProjection` is not thread-safe and we won't use it in multiple threads. I think the reason that we return `() => MutableProjection` is not about thread safety, but to save the costs of generating code when we need same but individual mutable projections.
However, I only found one place that use this [feature](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala#L122-L123), and comparing to the troubles it brings, I think we should generate `MutableProjection` directly instead of return a function.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #7373 from cloud-fan/project.
Diffstat (limited to 'sql/core/src/main')
8 files changed, 20 insertions, 22 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 415cd4d84a..b64352a9e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -352,12 +352,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } - private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute], - useSubexprElimination: Boolean = false): () => MutableProjection = { + useSubexprElimination: Boolean = false): MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 8e9214fa25..85ce388de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -120,14 +120,14 @@ case class Window( val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. val exprs = orderSpec.map(_.child) - val projection = newMutableProjection(exprs, child.output) - (orderSpec, projection(), projection()) + val buildProjection = () => newMutableProjection(exprs, child.output) + (orderSpec, buildProjection(), buildProjection()) } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output)() + val current = newMutableProjection(expr :: Nil, child.output) // Flip the sign of the offset when processing the order is descending val boundOffset = sortExpr.direction match { case Descending => -offset @@ -135,7 +135,7 @@ case class Window( } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output)() + val bound = newMutableProjection(boundExpr :: Nil, child.output) (sortExpr :: Nil, current, bound) } else { sys.error("Non-Zero range offsets are not supported for windows " + @@ -564,7 +564,7 @@ private[execution] final class OffsetWindowFunctionFrame( ordinal: Int, expressions: Array[Expression], inputSchema: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ @@ -604,7 +604,7 @@ private[execution] final class OffsetWindowFunctionFrame( } // Create the projection. - newMutableProjection(boundExpressions, Nil)().target(target) + newMutableProjection(boundExpressions, Nil).target(target) } override def prepare(rows: RowBuffer): Unit = { @@ -886,7 +886,7 @@ private[execution] object AggregateProcessor { functions: Array[Expression], ordinal: Int, inputAttributes: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection): AggregateProcessor = { val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] val initialValues = mutable.Buffer.empty[Expression] @@ -938,13 +938,13 @@ private[execution] object AggregateProcessor { // Create the projections. val initialProjection = newMutableProjection( initialValues, - partitionSize.toSeq)() + partitionSize.toSeq) val updateProjection = newMutableProjection( updateExpressions, - aggBufferAttributes ++ inputAttributes)() + aggBufferAttributes ++ inputAttributes) val evaluateProjection = newMutableProjection( evaluateExpressions, - aggBufferAttributes)() + aggBufferAttributes) // Create the processor new AggregateProcessor( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 042c731901..81aacb437b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -39,7 +39,7 @@ abstract class AggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// @@ -139,7 +139,7 @@ abstract class AggregationIterator( // no-op expressions which are ignored during projection code-generation. case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } - newMutableProjection(initExpressions, Nil)() + newMutableProjection(initExpressions, Nil) } // All imperative AggregateFunctions. @@ -175,7 +175,7 @@ abstract class AggregationIterator( // This projection is used to merge buffer values for all expression-based aggregates. val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = - newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes) (currentBuffer: MutableRow, row: InternalRow) => { // Process all expression-based aggregate functions. @@ -211,7 +211,7 @@ abstract class AggregationIterator( case agg: AggregateFunction => NoOp } val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) expressionAggEvalProjection.target(aggregateResult) val resultProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index de1491d357..c35d781d3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -34,7 +34,7 @@ class SortBasedAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, numOutputRows: LongSQLMetric) extends AggregationIterator( groupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 89977f9e08..d4cef8f310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -335,7 +335,7 @@ case class TungstenAggregate( val mergeProjection = newMutableProjection( mergeExpr, aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), - subexpressionEliminationEnabled)() + subexpressionEliminationEnabled) val joinedRow = new JoinedRow() var currentKey: UnsafeRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 09384a482d..c368726610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -82,7 +82,7 @@ class TungstenAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[(Int, Int)], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index f5776e7b8d..4ceb710f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -361,7 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - GenerateMutableProjection.generate(children, inputAttributes)() + GenerateMutableProjection.generate(children, inputAttributes) } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index c9ab40a0a9..c49f173ad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -86,7 +86,7 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c } }.toArray }.toArray - val projection = newMutableProjection(allInputs, child.output)() + val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.map(dt => StructField("", dt))) val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) |