aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala19
3 files changed, 15 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b644f6ad30..bdcdc5d47c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,7 +74,7 @@ class Dataset[T] private[sql](
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(queryExecution.analyzed.output)
+ unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
private implicit def classTag = resolvedTEncoder.clsTag
@@ -375,7 +375,7 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder,
+ resolvedTEncoder.bind(queryExecution.analyzed.output),
queryExecution.analyzed.output).named :: Nil,
logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 3f84e22a10..7e5acbe851 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql](
private implicit val unresolvedKEncoder = encoderFor(kEncoder)
private implicit val unresolvedTEncoder = encoderFor(tEncoder)
- private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
- private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
+ private val resolvedKEncoder =
+ unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
+ private val resolvedTEncoder =
+ unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 3f2775896b..6ce41aaf01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -52,8 +52,8 @@ object TypedAggregateExpression {
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
- aEncoder: Option[ExpressionEncoder[Any]],
- bEncoder: ExpressionEncoder[Any],
+ aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
+ bEncoder: ExpressionEncoder[Any], // Should be bound.
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
@@ -92,9 +92,6 @@ case class TypedAggregateExpression(
// We let the dataset do the binding for us.
lazy val boundA = aEncoder.get
- val bAttributes = bEncoder.schema.toAttributes
- lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
-
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
// todo: need a more neat way to assign the value.
var i = 0
@@ -114,24 +111,24 @@ case class TypedAggregateExpression(
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
- val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer, returned)
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
- val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+ val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
+ val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
val merged = aggregator.merge(b1, b2)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer1, returned)
}
override def eval(buffer: InternalRow): Any = {
- val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val result = cEncoder.toRow(aggregator.finish(b))
dataType match {
case _: StructType => result