author     Michael Armbrust <michael@databricks.com>  2015-11-18 16:48:09 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-18 16:48:09 -0800
commit     59a501359a267fbdb7689058693aa788703e54b1 (patch)
tree       5d1f5d19544a170803f33399ed6eeb5a7e18b900
parent     921900fd06362474f8caac675803d526a0986d70 (diff)
download   spark-59a501359a267fbdb7689058693aa788703e54b1.tar.gz
           spark-59a501359a267fbdb7689058693aa788703e54b1.tar.bz2
           spark-59a501359a267fbdb7689058693aa788703e54b1.zip
[SPARK-11636][SQL] Support classes defined in the REPL with Encoders
Before this PR there were two things that would blow up if you called `df.as[MyClass]` when `MyClass` was defined in the REPL:

- [x] `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName`.
- [x] We don't have anything to pass into the constructor for the `$outer` pointer.

Note that this PR only adds the infrastructure for working with inner classes in encoders and is not yet sufficient to make them work in the REPL. Currently, the implementation shown in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab causes a bug that breaks code generation due to an interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR.

Author: Michael Armbrust <michael@databricks.com>

Closes #9602 from marmbrus/dataset-replClasses.
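For context, a hedged sketch of the scenario this patch targets; `sqlContext`, `df`, and `MyClass` below are illustrative stand-ins, not part of the patch:

```scala
// Hypothetical spark-shell session (Spark 1.6-era API). A case class defined
// in a REPL cell is compiled as an inner class of the generated line wrapper,
// so its runtime name is munged (e.g. `$iwC$$iwC$MyClass`) and its
// constructor takes a hidden `$outer` argument.
import sqlContext.implicits._

case class MyClass(a: Int)

val df = Seq(1, 2, 3).toDF("a")
// Before this patch, this call could blow up twice: `classForName` on the
// munged name, and construction without an `$outer` instance to pass.
val ds = df.as[MyClass]
```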
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 81
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 26
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala | 42
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 10
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 42
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 7
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 19
17 files changed, 193 insertions(+), 82 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 38828e59a2..59ccf356f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection {
// class loader of the current thread.
override def mirror: universe.Mirror =
universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
-}
-
-/**
- * Support for generating catalyst schemas for scala objects.
- */
-trait ScalaReflection {
- /** The universe we work in (runtime or macro) */
- val universe: scala.reflect.api.Universe
-
- /** The mirror used to access types in the universe */
- def mirror: universe.Mirror
import universe._
@@ -53,30 +42,6 @@ trait ScalaReflection {
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
import scala.collection.Map
- case class Schema(dataType: DataType, nullable: Boolean)
-
- /** Returns a Sequence of attributes for the given case class type. */
- def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
- case Schema(s: StructType, _) =>
- s.toAttributes
- }
-
- /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor[T: TypeTag]: Schema =
- ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }
-
- /**
- * Return the Scala Type for `T` in the current classloader mirror.
- *
- * Use this method instead of the convenience method `universe.typeOf`, which
- * assumes that all types can be found in the classloader that loaded scala-reflect classes.
- * That's not necessarily the case when running using Eclipse launchers or even
- * Sbt console or test (without `fork := true`).
- *
- * @see SPARK-5281
- */
- def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
-
/**
* Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
* to a native type, an ObjectType is returned. Special handling is also used for Arrays including
@@ -114,7 +79,9 @@ trait ScalaReflection {
}
ObjectType(cls)
- case other => ObjectType(Utils.classForName(className))
+ case other =>
+ val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+ ObjectType(clazz)
}
}
@@ -640,6 +607,48 @@ trait ScalaReflection {
}
}
}
+}
+
+/**
+ * Support for generating catalyst schemas for scala objects. Note that unlike its companion
+ * object, this trait is able to work in both the runtime and compile-time (macro) universes.
+ */
+trait ScalaReflection {
+ /** The universe we work in (runtime or macro) */
+ val universe: scala.reflect.api.Universe
+
+ /** The mirror used to access types in the universe */
+ def mirror: universe.Mirror
+
+ import universe._
+
+ // The Predef.Map is scala.collection.immutable.Map.
+ // Since the map values can be mutable, we explicitly import scala.collection.Map here.
+ import scala.collection.Map
+
+ case class Schema(dataType: DataType, nullable: Boolean)
+
+ /** Returns a Sequence of attributes for the given case class type. */
+ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
+ case Schema(s: StructType, _) =>
+ s.toAttributes
+ }
+
+ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ def schemaFor[T: TypeTag]: Schema =
+ ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }
+
+ /**
+ * Return the Scala Type for `T` in the current classloader mirror.
+ *
+ * Use this method instead of the convenience method `universe.typeOf`, which
+ * assumes that all types can be found in the classloader that loaded scala-reflect classes.
+ * That's not necessarily the case when running using Eclipse launchers or even
+ * Sbt console or test (without `fork := true`).
+ *
+ * @see SPARK-5281
+ */
+ def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
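The substantive change in this file is resolving the runtime class through the mirror instead of by string name. A minimal sketch of the pattern, assuming the same runtime mirror as above (`runtimeClassOf` is an illustrative helper, not part of the patch):

```scala
import scala.reflect.runtime.{universe => ru}

val mirror: ru.Mirror =
  ru.runtimeMirror(Thread.currentThread().getContextClassLoader)

// Resolve through the mirror's class loader, as localTypeOf does, rather than
// via Class.forName on a (possibly REPL-munged) fully-qualified name.
def runtimeClassOf[T: ru.TypeTag]: Class[_] = {
  val tpe = ru.typeTag[T].in(mirror).tpe
  mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}
```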
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index b977f278c5..456b595008 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.encoders
+import java.util.concurrent.ConcurrentMap
+
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.util.Utils
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
@@ -211,7 +213,9 @@ case class ExpressionEncoder[T](
* Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
* given schema.
*/
- def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ def resolve(
+ schema: Seq[Attribute],
+ outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
val positionToAttribute = AttributeMap.toIndex(schema)
val unbound = fromRowExpression transform {
case b: BoundReference => positionToAttribute(b.ordinal)
@@ -219,7 +223,23 @@ case class ExpressionEncoder[T](
val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
- copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
+
+ // In order to construct instances of inner classes (for example those declared in a REPL cell),
+ // we need an instance of the outer scope. This rule substitutes those outer objects into
+ // expressions that are missing them by looking up the name in the SQLContext's `outerScopes`
+ // registry.
+ copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
+ case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+ val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
+ s"to the scope that this class was defined in. " +
+ "Try moving this class out of its parent class.")
+ }
+
+ n.copy(outerPointer = Some(Literal.fromObject(outer)))
+ })
}
/**
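Why an outer instance is needed at all: on the JVM an inner class can only be instantiated relative to an instance of its enclosing class, which is exactly what the injected `Literal` supplies. A self-contained sketch (`Wrapper` is illustrative):

```scala
class Wrapper {             // stands in for the REPL's generated line object
  case class Inner(i: Int)  // inner class: its constructor needs `$outer`
}

val wrapper = new Wrapper
val ok = new wrapper.Inner(1) // fine: the enclosing instance is the prefix
// With no `wrapper` value in hand there is nothing to pass as `$outer`,
// which is the situation the AnalysisException above reports.
```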
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
new file mode 100644
index 0000000000..a753b187bc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.collect.MapMaker
+
+object OuterScopes {
+ @transient
+ lazy val outerScopes: ConcurrentMap[String, AnyRef] =
+ new MapMaker().weakValues().makeMap()
+
+ /**
+ * Adds a new outer scope to this context that can be used when instantiating an `inner class`
+ * during deserialization. Inner classes are created when a case class is defined in the
+ * Spark REPL, and registering the outer scope that this class was defined in allows us to
+ * create new instances on the Spark executors. In normal use, users should not need to call this
+ * function.
+ *
+ * Warning: this function operates on the assumption that there is only ever one instance of any
+ * given wrapper class.
+ */
+ def addOuterScope(outer: AnyRef): Unit = {
+ outerScopes.putIfAbsent(outer.getClass.getName, outer)
+ }
+}
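A sketch of the intended usage, with `ReplLineWrapper` as a hypothetical stand-in for the REPL's generated wrapper class:

```scala
import org.apache.spark.sql.catalyst.encoders.OuterScopes

class ReplLineWrapper {
  case class Person(name: String) // the inner class we want an encoder for

  // Registers this instance under getClass.getName; resolve() later looks up
  // the declaring class's name in the same map. Values are weakly held, so
  // registration does not leak wrapper instances.
  OuterScopes.addOuterScope(this)
}

new ReplLineWrapper // registration happens on construction
```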
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 55c4ee11b2..2914c6ee79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -31,6 +31,7 @@ import scala.reflect.ClassTag
object ProductEncoder {
import ScalaReflection.universe._
+ import ScalaReflection.mirror
import ScalaReflection.localTypeOf
import ScalaReflection.dataTypeFor
import ScalaReflection.Schema
@@ -420,8 +421,7 @@ object ProductEncoder {
}
}
- val className: String = t.erasure.typeSymbol.asClass.fullName
- val cls = Utils.classForName(className)
+ val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
@@ -429,7 +429,7 @@ object ProductEncoder {
val dataType = schemaFor(fieldType).dataType
// For tuples, we grab the inner fields by ordinal instead of by name.
- if (className startsWith "scala.Tuple") {
+ if (cls.getName startsWith "scala.Tuple") {
constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
} else {
constructorFor(fieldType, Some(addToPath(fieldName)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index d51a8dede7..a31574c251 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -34,7 +34,7 @@ trait CodegenFallback extends Expression {
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
- Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+ java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 4b66069b5f..40189f0877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
val code = s"""
- public Object generate($exprType[] expr) {
+ public java.lang.Object generate($exprType[] expr) {
return new SpecificMutableProjection(expr);
}
@@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}
- public Object apply(Object _i) {
+ public java.lang.Object apply(java.lang.Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allProjections
// copy all the results into MutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c0d313b2e1..f229f2000d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
${initMutableStates(ctx)}
}
- public Object apply(Object r) {
+ public java.lang.Object apply(java.lang.Object r) {
// GenerateProjection does not work with UnsafeRows.
assert(!(r instanceof ${classOf[UnsafeRow].getName}));
return new SpecificRow((InternalRow) r);
@@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
- public Object genericGet(int i) {
+ public java.lang.Object genericGet(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
}
return null;
}
- public void update(int i, Object value) {
+ public void update(int i, java.lang.Object value) {
if (value == null) {
setNullAt(i);
return;
@@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
return result;
}
- public boolean equals(Object other) {
+ public boolean equals(java.lang.Object other) {
if (other instanceof SpecificRow) {
SpecificRow row = (SpecificRow) other;
$columnChecks
@@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
public InternalRow copy() {
- Object[] arr = new Object[${expressions.length}];
+ java.lang.Object[] arr = new java.lang.Object[${expressions.length}];
${copyColumns}
return new ${classOf[GenericInternalRow].getName}(arr);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f0ed8645d9..b7926bda3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
val code = s"""
- public Object generate($exprType[] expr) {
+ public java.lang.Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
}
@@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
${initMutableStates(ctx)}
}
- public Object apply(Object _i) {
+ public java.lang.Object apply(java.lang.Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allExpressions
return mutableRow;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 4c17d02a23..7b6c9373eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
val code = s"""
- public Object generate($exprType[] exprs) {
+ public java.lang.Object generate($exprType[] exprs) {
return new SpecificUnsafeProjection(exprs);
}
@@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
// Scala.Function1 needs this
- public Object apply(Object row) {
+ public java.lang.Object apply(java.lang.Object row) {
return apply((InternalRow) row);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index da91ff2953..da602d9b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
- |public Object generate($exprType[] exprs) {
+ |public java.lang.Object generate($exprType[] exprs) {
| return new SpecificUnsafeRowJoiner();
|}
|
@@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| buf = new byte[sizeInBytes];
| }
|
- | final Object obj1 = row1.getBaseObject();
+ | final java.lang.Object obj1 = row1.getBaseObject();
| final long offset1 = row1.getBaseOffset();
- | final Object obj2 = row2.getBaseObject();
+ | final java.lang.Object obj2 = row2.getBaseObject();
| final long offset2 = row2.getBaseOffset();
|
| $copyBitset
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 455fa2427c..e34fd49be8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -48,6 +48,12 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
+ /**
+ * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
+ * into code generation.
+ */
+ def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
+
def create(v: Any, dataType: DataType): Literal = {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
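A small sketch of the new helper in isolation: it wraps an arbitrary JVM object in a `Literal` of `ObjectType`, which is how `resolve` above passes the outer instance into generated code.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.ObjectType

val outerRef: AnyRef = new Object // any live object, e.g. a REPL wrapper
val lit = Literal.fromObject(outerRef)
assert(lit.dataType == ObjectType(classOf[Object]))
```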
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index acf0da2400..f865a9408e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.encoders.ProductEncoder
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
@@ -178,6 +179,15 @@ case class Invoke(
}
}
+object NewInstance {
+ def apply(
+ cls: Class[_],
+ arguments: Seq[Expression],
+ propagateNull: Boolean = false,
+ dataType: DataType): NewInstance =
+ new NewInstance(cls, arguments, propagateNull, dataType, None)
+}
+
/**
* Constructs a new instance of the given class, using the result of evaluating the specified
* expressions as arguments.
@@ -189,12 +199,15 @@ case class Invoke(
* @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you
* to manually specify the type when the object in question is a valid internal
* representation (i.e. ArrayData) instead of an object.
+ * @param outerPointer If the object being constructed is an inner class, the outerPointer
+ *                     to the containing class must be specified.
*/
case class NewInstance(
cls: Class[_],
arguments: Seq[Expression],
- propagateNull: Boolean = true,
- dataType: DataType) extends Expression {
+ propagateNull: Boolean,
+ dataType: DataType,
+ outerPointer: Option[Literal]) extends Expression {
private val className = cls.getName
override def nullable: Boolean = propagateNull
@@ -209,30 +222,43 @@ case class NewInstance(
val argGen = arguments.map(_.gen(ctx))
val argString = argGen.map(_.value).mkString(", ")
+ val outer = outerPointer.map(_.gen(ctx))
+
+ val setup =
+ s"""
+ ${argGen.map(_.code).mkString("\n")}
+ ${outer.map(_.code.mkString("")).getOrElse("")}
+ """.stripMargin
+
+ val constructorCall = outer.map { gen =>
+ s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
+ }.getOrElse {
+ s"new $className($argString)"
+ }
+
if (propagateNull) {
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
-
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
+
s"""
- ${argGen.map(_.code).mkString("\n")}
+ $setup
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-
if ($argsNonNull) {
- ${ev.value} = new $className($argString);
+ ${ev.value} = $constructorCall;
${ev.isNull} = false;
}
"""
} else {
s"""
- ${argGen.map(_.code).mkString("\n")}
+ $setup
- $javaType ${ev.value} = new $className($argString);
+ $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = ${ev.value} == null;
"""
}
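A sketch of building the expression through the new companion `apply`, which leaves `outerPointer` as `None`; for inner classes, `resolve` fills it in with `Literal.fromObject` before codegen. `Point` is an illustrative top-level class:

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, NewInstance}
import org.apache.spark.sql.types.ObjectType

case class Point(x: Int) // top level: no outer pointer required

val expr = NewInstance(
  classOf[Point],
  arguments = Literal(1) :: Nil,
  propagateNull = false,
  dataType = ObjectType(classOf[Point]))
```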
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 9fe64b4cf1..cde0364f3d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -18,6 +18,9 @@
package org.apache.spark.sql.catalyst.encoders
import java.util.Arrays
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.collect.MapMaker
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.ArrayType
abstract class ExpressionEncoderSuite extends SparkFunSuite {
+ val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
+
protected def encodeDecodeTest[T](
input: T,
encoder: ExpressionEncoder[T],
@@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite {
test(s"encode/decode for $testName: $input") {
val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
- val boundEncoder = encoder.resolve(schema).bind(schema)
+ val boundEncoder = encoder.resolve(schema, outers).bind(schema)
val convertedBack = try boundEncoder.fromRow(row) catch {
case e: Exception =>
fail(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index bc539d62c5..1798514c5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -53,6 +53,10 @@ case class RepeatedData(
case class SpecificCollection(l: List[Int])
class ProductEncoderSuite extends ExpressionEncoderSuite {
+ outers.put(getClass.getName, this)
+
+ case class InnerClass(i: Int)
+ productTest(InnerClass(1))
productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b644f6ad30..bdcdc5d47c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,7 +74,7 @@ class Dataset[T] private[sql](
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(queryExecution.analyzed.output)
+ unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
private implicit def classTag = resolvedTEncoder.clsTag
@@ -375,7 +375,7 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder,
+ resolvedTEncoder.bind(queryExecution.analyzed.output),
queryExecution.analyzed.output).named :: Nil,
logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 3f84e22a10..7e5acbe851 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql](
private implicit val unresolvedKEncoder = encoderFor(kEncoder)
private implicit val unresolvedTEncoder = encoderFor(tEncoder)
- private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
- private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
+ private val resolvedKEncoder =
+ unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
+ private val resolvedTEncoder =
+ unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 3f2775896b..6ce41aaf01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -52,8 +52,8 @@ object TypedAggregateExpression {
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
- aEncoder: Option[ExpressionEncoder[Any]],
- bEncoder: ExpressionEncoder[Any],
+ aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
+ bEncoder: ExpressionEncoder[Any], // Should be bound.
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
@@ -92,9 +92,6 @@ case class TypedAggregateExpression(
// We let the dataset do the binding for us.
lazy val boundA = aEncoder.get
- val bAttributes = bEncoder.schema.toAttributes
- lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
-
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
// TODO: need a neater way to assign the value.
var i = 0
@@ -114,24 +111,24 @@ case class TypedAggregateExpression(
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
- val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer, returned)
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
- val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+ val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
+ val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
val merged = aggregator.merge(b1, b2)
- val returned = boundB.toRow(merged)
+ val returned = bEncoder.toRow(merged)
updateBuffer(buffer1, returned)
}
override def eval(buffer: InternalRow): Any = {
- val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
val result = cEncoder.toRow(aggregator.finish(b))
dataType match {
case _: StructType => result