diff options
Diffstat (limited to 'sql/core')
3 files changed, 15 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad30..bdcdc5d47c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a10..7e5acbe851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896b..6ce41aaf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result |