aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-11-18 16:48:09 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-18 16:48:09 -0800
commit59a501359a267fbdb7689058693aa788703e54b1 (patch)
tree5d1f5d19544a170803f33399ed6eeb5a7e18b900 /sql/core
parent921900fd06362474f8caac675803d526a0986d70 (diff)
downloadspark-59a501359a267fbdb7689058693aa788703e54b1.tar.gz
spark-59a501359a267fbdb7689058693aa788703e54b1.tar.bz2
spark-59a501359a267fbdb7689058693aa788703e54b1.zip
[SPARK-11636][SQL] Support classes defined in the REPL with Encoders
Before this PR there were two things that would blow up if you called `df.as[MyClass]` if `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoder and is not yet sufficient to make them work in the REPL. Currently, the implementation show in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust <michael@databricks.com> Closes #9602 from marmbrus/dataset-replClasses.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala19
3 files changed, 15 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b644f6ad30..bdcdc5d47c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,7 +74,7 @@ class Dataset[T] private[sql](
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(queryExecution.analyzed.output)
+ unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
private implicit def classTag = resolvedTEncoder.clsTag
@@ -375,7 +375,7 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder,
+ resolvedTEncoder.bind(queryExecution.analyzed.output),
queryExecution.analyzed.output).named :: Nil,
logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 3f84e22a10..7e5acbe851 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql](
private implicit val unresolvedKEncoder = encoderFor(kEncoder)
private implicit val unresolvedTEncoder = encoderFor(tEncoder)
- private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
- private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
+ private val resolvedKEncoder =
+ unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
+ private val resolvedTEncoder =
+ unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 3f2775896b..6ce41aaf01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -52,8 +52,8 @@ object TypedAggregateExpression {
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
- aEncoder: Option[ExpressionEncoder[Any]],
- bEncoder: ExpressionEncoder[Any],
+ aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
+ bEncoder: ExpressionEncoder[Any], // Should be bound.
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
@@ -92,9 +92,6 @@ case class TypedAggregateExpression(
// We let the dataset do the binding for us.
lazy val boundA = aEncoder.get
- val bAttributes = bEncoder.schema.toAttributes
- lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
-
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
// todo: need a more neat way to assign the value.
var i = 0
@@ -114,24 +111,24 @@ case class TypedAggregateExpression(
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
- val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer, returned)
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
- val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+ val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
+ val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
val merged = aggregator.merge(b1, b2)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer1, returned)
}
override def eval(buffer: InternalRow): Any = {
- val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val result = cEncoder.toRow(aggregator.finish(b))
dataType match {
case _: StructType => result