-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 82
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 68
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala | 31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 64
12 files changed, 230 insertions, 104 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b59eb12419..cb228cf52b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
@@ -457,25 +458,34 @@ class Analyzer(
// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on its descendants
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
- val newOrdering = resolveSortOrders(ordering, child, throws = false)
+ val newOrdering =
+ ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
Sort(newOrdering, global, child)
// A special case for Generate, because the output of Generate should not be resolved by
// ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
case g @ Generate(generator, join, outer, qualifier, output, child)
if child.resolved && !generator.resolved =>
- val newG = generator transformUp {
- case u @ UnresolvedAttribute(nameParts) =>
- withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) }
- case UnresolvedExtractValue(child, fieldExpr) =>
- ExtractValue(child, fieldExpr, resolver)
- }
+ val newG = resolveExpression(generator, child, throws = true)
if (newG.fastEquals(generator)) {
g
} else {
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
+ // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
+ // should be resolved against their corresponding attributes instead of the children's output.
+ case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
+ val deserializerToAttributes = o.deserializers.map {
+ case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
+ }.toMap
+
+ o.transformExpressions {
+ case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
+ resolveDeserializer(expr, attributes)
+ }.getOrElse(expr)
+ }
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
@@ -490,6 +500,32 @@ class Analyzer(
}
}
+ private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
+ exprs.exists { expr =>
+ !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
+ }
+ }
+
+ def resolveDeserializer(
+ deserializer: Expression,
+ attributes: Seq[Attribute]): Expression = {
+ val unbound = deserializer transform {
+ case b: BoundReference => attributes(b.ordinal)
+ }
+
+ resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
+ case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+ val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+ n.copy(outerPointer = Some(Literal.fromObject(outer)))
+ }
+ }
+
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)()
@@ -508,23 +544,20 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
- private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
- ordering.map { order =>
- // Resolve SortOrder in one round.
- // If throws == false or the desired attribute doesn't exist
- // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
- // Else, throw exception.
- try {
- val newOrder = order transformUp {
- case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).getOrElse(u)
- case UnresolvedExtractValue(child, fieldName) if child.resolved =>
- ExtractValue(child, fieldName, resolver)
- }
- newOrder.asInstanceOf[SortOrder]
- } catch {
- case a: AnalysisException if !throws => order
+ private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+ // Resolve the expression in one round.
+ // If throws == false or the desired attribute doesn't exist
+ // (e.g. trying to resolve `a.b` when `a` doesn't exist), fail and return the original one.
+ // Otherwise, throw an exception.
+ try {
+ expr transformUp {
+ case u @ UnresolvedAttribute(nameParts) =>
+ withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+ case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+ ExtractValue(child, fieldName, resolver)
}
+ } catch {
+ case a: AnalysisException if !throws => expr
}
}
@@ -619,7 +652,8 @@ class Analyzer(
ordering: Seq[SortOrder],
plan: LogicalPlan,
child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- val newOrdering = resolveSortOrders(ordering, child, throws = false)
+ val newOrdering =
+ ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
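
Note: the new `resolveDeserializer` path matters when a deserializer must not be resolved against the child's full output. A minimal sketch, assuming a SQLContext with its implicits in scope; the case class names are illustrative, and the equivalent scenario is covered by the new DatasetSuite test "grouping key and grouped value has field with same name" at the end of this patch:

import sqlContext.implicits._

case class GroupKey(a: String)            // hypothetical key class
case class Record(a: String, b: Int)      // hypothetical value class; shares field name `a` with the key

val ds = Seq(Record("x", 1), Record("x", 2)).toDS()

// The key deserializer is resolved against the grouping attributes only, and the
// value deserializer against the data attributes only, so the shared name `a`
// is no longer ambiguous.
val summed = ds.groupBy(r => GroupKey(r.a)).mapGroups {
  case (key, values) => key.a + values.map(_.b).sum
}

summed.collect()   // Array("x3")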
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 64832dc114..58f6d0eb9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -50,7 +50,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(typeTag[T].tpe)
val flat = !classOf[Product].isAssignableFrom(cls)
- val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+ val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
val fromRowExpression = ScalaReflection.constructorFor[T]
@@ -257,12 +257,10 @@ case class ExpressionEncoder[T](
}
/**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
- * given schema.
+ * Validates `fromRowExpression` to make sure it can be resolved by the given schema, and produces
+ * a friendly error message explaining why resolution fails if something is wrong.
*/
- def resolve(
- schema: Seq[Attribute],
- outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+ def validate(schema: Seq[Attribute]): Unit = {
def fail(st: StructType, maxOrdinal: Int): Unit = {
throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
"but failed as the number of fields does not line up.\n" +
@@ -270,6 +268,8 @@ case class ExpressionEncoder[T](
" - Target schema: " + this.schema.simpleString)
}
+ // If this is a tuple encoder or tupled encoder, i.e. its leaf nodes are all
+ // `BoundReference`s, make sure their ordinals are all valid.
var maxOrdinal = -1
fromRowExpression.foreach {
case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
@@ -279,6 +279,10 @@ case class ExpressionEncoder[T](
fail(StructType.fromAttributes(schema), maxOrdinal)
}
+ // If we have nested tuples, the `fromRowExpression` will contain `GetStructField` instead of
+ // `UnresolvedExtractValue`, so we need to check that their ordinals are all valid.
+ // Note that `BoundReference` carries the expected type, but here we need the actual type, so
+ // we unbind it using the given `schema` and propagate the actual type to `GetStructField`.
val unbound = fromRowExpression transform {
case b: BoundReference => schema(b.ordinal)
}
@@ -299,28 +303,24 @@ case class ExpressionEncoder[T](
fail(schema, maxOrdinal)
}
}
+ }
- val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+ * given schema.
+ */
+ def resolve(
+ schema: Seq[Attribute],
+ outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+ val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
+ fromRowExpression, schema)
+
+ // Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer,
+ // check analysis, go through the optimizer, etc.
+ val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
- val optimizedPlan = SimplifyCasts(analyzedPlan)
-
- // In order to construct instances of inner classes (for example those declared in a REPL cell),
- // we need an instance of the outer scope. This rule substitues those outer objects into
- // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
- // registry.
- copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
- case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
- val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
- if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
- s"to the scope that this class was defined in. " + "" +
- "Try moving this class out of its parent class.")
- }
-
- n.copy(outerPointer = Some(Literal.fromObject(outer)))
- })
+ copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
}
/**
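
Note: a hedged sketch of the new `validate` entry point, using the catalyst expression DSL the way EncoderResolutionSuite does below. It only checks that the given schema can feed the deserializer and never rewrites the encoder:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val encoder = ExpressionEncoder[(String, Long)]()      // a two-field tuple encoder

// Three input columns cannot line up with Tuple2, so this throws an
// AnalysisException: "Try to map struct<a:string,b:bigint,c:int> to Tuple2,
// but failed as the number of fields does not line up. ..."
encoder.validate(Seq('a.string, 'b.long, 'c.int))

// A matching two-column schema should validate silently (returns Unit).
encoder.validate(Seq('a.string, 'b.long))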
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 89d40b3b2c..d8f755a39c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -154,7 +154,7 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(BoundReference(i, f.dataType, f.nullable))
+ constructorFor(field)
)
}
CreateExternalRow(fields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a1ac930739..902e18081b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,10 +119,13 @@ object SamplePushDown extends Rule[LogicalPlan] {
*/
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case m @ MapPartitions(_, input, _, child: ObjectOperator)
- if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+ case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
- m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
}
}
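
Note: the rename from `input` to `deserializer` does not change what `EliminateSerialization` does; it still targets chains of object operators where the downstream operator can consume the upstream operator's output object directly instead of deserializing it from a row again. A minimal usage sketch of such a chain, assuming a SQLContext with implicits in scope:

val ds = Seq(1, 2, 3).toDS()

// Two back-to-back object operators: without the rule, the intermediate result
// would be serialized to an InternalRow and immediately deserialized again.
val chained = ds
  .mapPartitions(iter => iter.map(_ + 1))
  .mapPartitions(iter => iter.map(_ * 2))

chained.collect()   // Array(4, 6, 8)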
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 7603480527..3f97662957 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{ObjectType, StructType}
/**
* A trait for logical operators that apply user defined functions to domain objects.
@@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan {
/** The serializer that is used to produce the output of this operator. */
def serializer: Seq[NamedExpression]
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ /**
+ * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
+ * It must also provide the attributes that are available during the resolution of each
+ * deserializer.
+ */
+ def deserializers: Seq[(Expression, Seq[Attribute])]
+
/**
* The object type that is produced by the user defined function. Note that the return type here
* is the same whether or not the operator is output serialized data.
@@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan {
def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
this
} else {
- withNewSerializer(outputObject)
+ withNewSerializer(outputObject :: Nil)
}
/** Returns a copy of this operator with a different serializer. */
- def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+ def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy {
productIterator.map {
- case c if c == serializer => newSerializer :: Nil
+ case c if c == serializer => newSerializer
case other: AnyRef => other
}.toArray
}
@@ -70,15 +79,16 @@ object MapPartitions {
/**
* A relation produced by applying `func` to each partition of the `child`.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `AppendColumn` nodes. */
@@ -97,16 +107,21 @@ object AppendColumns {
/**
* A relation produced by applying `func` to each partition of the `child`, concatenating the
* resulting columns at the end of the input row.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class AppendColumns(
func: Any => Any,
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
override def output: Seq[Attribute] = child.output ++ newColumns
+
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -114,6 +129,7 @@ object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -121,6 +137,7 @@ object MapGroups {
encoderFor[T].fromRowExpression,
encoderFor[U].namedExpressions,
groupingAttributes,
+ dataAttributes,
child)
}
}
@@ -129,19 +146,22 @@ object MapGroups {
* Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
* Func is invoked with an object representation of the grouping key an iterator containing the
* object representation of all the rows with that key.
- * @param keyObject used to extract the key object for each group.
- * @param input used to extract the items in the iterator from an input row.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- input: Expression,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
- def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] =
+ Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
}
/** Factory for constructing new `CoGroup` nodes. */
@@ -150,8 +170,12 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftData: Seq[Attribute],
+ rightData: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
+ require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
+
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
encoderFor[Key].fromRowExpression,
@@ -160,6 +184,8 @@ object CoGroup {
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
+ leftData,
+ rightData,
left,
right)
}
@@ -171,15 +197,21 @@ object CoGroup {
*/
case class CoGroup(
func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- leftObject: Expression,
- rightObject: Expression,
+ keyDeserializer: Expression,
+ leftDeserializer: Expression,
+ rightDeserializer: Expression,
serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode with ObjectOperator {
+
override def producedAttributes: AttributeSet = outputSet
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] =
+ // `leftGroup` and `rightGroup` are guaranteed to have the same schema, so it's safe to resolve
+ // the `keyDeserializer` against either of them; here we pick the left one.
+ Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index bc36a55ae0..92a68a4dba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -127,7 +127,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.long, 'c.int)
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct<a:string,b:bigint,c:int>\n" +
@@ -136,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string)
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct<a:string> to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct<a:string>\n" +
@@ -149,7 +149,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
@@ -158,7 +158,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.struct('x.long))
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct<x:bigint> to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 88c558d80a..e00060f9b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders
import java.sql.{Date, Timestamp}
import java.util.Arrays
-import java.util.concurrent.ConcurrentMap
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
-import com.google.common.collect.MapMaker
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
@@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable {
}
class ExpressionEncoderSuite extends SparkFunSuite {
- OuterScopes.outerScopes.put(getClass.getName, this)
+ OuterScopes.addOuterScope(this)
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
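
Note: `OuterScopes.addOuterScope` replaces the direct map manipulation used before. A minimal sketch mirroring the "support inner class in Dataset" test added to DatasetSuite below, assuming a SQLContext with implicits in scope:

import org.apache.spark.sql.catalyst.encoders.OuterScopes

class OuterClass extends Serializable {
  case class InnerClass(a: String)      // a member class needs an outer pointer to be instantiated
}

val outer = new OuterClass
OuterScopes.addOuterScope(outer)        // register the enclosing instance before creating the Dataset

val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
ds.map(_.a).collect()                   // Array("1", "2")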
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f182270a08..378763268a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,6 +74,7 @@ class Dataset[T] private[sql](
* same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+ unresolvedTEncoder.validate(logicalPlan.output)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
@@ -85,7 +86,7 @@ class Dataset[T] private[sql](
*/
private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
- private implicit def classTag = resolvedTEncoder.clsTag
+ private implicit def classTag = unresolvedTEncoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
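
Note: calling `unresolvedTEncoder.validate(logicalPlan.output)` in the constructor means encoder/schema mismatches now surface when the Dataset is defined rather than when an action runs; the updated "verify mismatching field names fail with a good error" test below drops its `.collect()` accordingly. A hedged sketch, reusing DatasetSuite's ClassData and ClassData2 (per the asserted error message, ClassData2 expects a field `c` that the input lacks):

val ds = Seq(ClassData("a", 1)).toDS()

// Throws AnalysisException("cannot resolve 'c' given input columns: [a, b]")
// when the Dataset is defined; no action is needed to trigger the failure any more.
val mismatched = ds.as[ClassData2]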
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b3f8284364..c0e28f2dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql](
MapGroups(
f,
groupingAttributes,
+ dataAttributes,
logicalPlan))
}
@@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql](
f,
this.groupingAttributes,
other.groupingAttributes,
+ this.dataAttributes,
+ other.dataAttributes,
this.logicalPlan,
other.logicalPlan))
}
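
Note: with `dataAttributes` threaded through `MapGroups` and `CoGroup`, each side's value deserializer is resolved against that side's own attributes, so the left and right datasets may share field names. A sketch mirroring the new "cogroup's left and right side has field with same name" test below, assuming implicits in scope and DatasetSuite's ClassData / ClassNullableData:

val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()

// Both classes have fields named `a` and `b`; the deserializers no longer clash.
val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
  case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
}

cogrouped.collect()   // Array("a13", "b24")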
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9293e55141..830bb011be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -306,11 +306,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
- case logical.MapGroups(f, key, in, out, grouping, child) =>
- execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
- case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+ case logical.MapGroups(f, key, in, out, grouping, data, child) =>
+ execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil
+ case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) =>
execution.CoGroup(
- f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
+ f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr,
+ planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2acca1743c..582dda8603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan {
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with ObjectOperator {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(input, child.output)
+ val getObject = generateToObject(deserializer, child.output)
val outputObject = generateToRow(serializer)
func(iter.map(getObject)).map(outputObject)
}
@@ -72,7 +72,7 @@ case class MapPartitions(
*/
case class AppendColumns(
func: Any => Any,
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with ObjectOperator {
@@ -82,7 +82,7 @@ case class AppendColumns(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(input, child.output)
+ val getObject = generateToObject(deserializer, child.output)
val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
val outputObject = generateToRow(serializer)
@@ -103,10 +103,11 @@ case class AppendColumns(
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- input: Expression,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: SparkPlan) extends UnaryNode with ObjectOperator {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
@@ -121,8 +122,8 @@ case class MapGroups(
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
- val getKey = generateToObject(keyObject, groupingAttributes)
- val getValue = generateToObject(input, child.output)
+ val getKey = generateToObject(keyDeserializer, groupingAttributes)
+ val getValue = generateToObject(valueDeserializer, dataAttributes)
val outputObject = generateToRow(serializer)
grouped.flatMap { case (key, rowIter) =>
@@ -142,12 +143,14 @@ case class MapGroups(
*/
case class CoGroup(
func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- leftObject: Expression,
- rightObject: Expression,
+ keyDeserializer: Expression,
+ leftDeserializer: Expression,
+ rightDeserializer: Expression,
serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with ObjectOperator {
@@ -164,9 +167,9 @@ case class CoGroup(
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val getKey = generateToObject(keyObject, leftGroup)
- val getLeft = generateToObject(leftObject, left.output)
- val getRight = generateToObject(rightObject, right.output)
+ val getKey = generateToObject(keyDeserializer, leftGroup)
+ val getLeft = generateToObject(leftDeserializer, leftAttr)
+ val getRight = generateToObject(rightDeserializer, rightAttr)
val outputObject = generateToRow(serializer)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b69bb21db5..374f4320a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.postfixOps
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}
-
test("SPARK-12404: Datatype Helper Serializablity") {
val ds = sparkContext.parallelize((
- new Timestamp(0),
- new Date(0),
- java.math.BigDecimal.valueOf(1),
- scala.math.BigDecimal(1)) :: Nil).toDS()
+ new Timestamp(0),
+ new Date(0),
+ java.math.BigDecimal.valueOf(1),
+ scala.math.BigDecimal(1)) :: Nil).toDS()
ds.collect()
}
@@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("verify mismatching field names fail with a good error") {
val ds = Seq(ClassData("a", 1)).toDS()
val e = intercept[AnalysisException] {
- ds.as[ClassData2].collect()
+ ds.as[ClassData2]
}
assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
}
@@ -567,6 +567,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
checkAnswer(ds1.toDF(), Row(Row(null)))
}
+
+ test("support inner class in Dataset") {
+ val outer = new OuterClass
+ OuterScopes.addOuterScope(outer)
+ val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
+ checkAnswer(ds.map(_.a), "1", "2")
+ }
+
+ test("grouping key and grouped value has field with same name") {
+ val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
+ val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+ case (key, values) => key.a + values.map(_.b).sum
+ }
+
+ checkAnswer(agged, "a3")
+ }
+
+ test("cogroup's left and right side has field with same name") {
+ val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
+ val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+ case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
+ }
+
+ checkAnswer(cogrouped, "a13", "b24")
+ }
+
+ test("give nice error message when the real number of fields doesn't match encoder schema") {
+ val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ val message = intercept[AnalysisException] {
+ ds.as[(String, Int, Long)]
+ }.message
+ assert(message ==
+ "Try to map struct<a:string,b:int> to Tuple3, " +
+ "but failed as the number of fields does not line up.\n" +
+ " - Input schema: struct<a:string,b:int>\n" +
+ " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+
+ val message2 = intercept[AnalysisException] {
+ ds.as[Tuple1[String]]
+ }.message
+ assert(message2 ==
+ "Try to map struct<a:string,b:int> to Tuple1, " +
+ "but failed as the number of fields does not line up.\n" +
+ " - Input schema: struct<a:string,b:int>\n" +
+ " - Target schema: struct<_1:string>")
+ }
+}
+
+class OuterClass extends Serializable {
+ case class InnerClass(a: String)
}
case class ClassData(a: String, b: Int)