author    Wenchen Fan <wenchen@databricks.com>  2016-04-05 10:53:54 -0700
committer Michael Armbrust <michael@databricks.com>  2016-04-05 10:53:54 -0700
commit    f77f11c67125fdac2e6849a4d45d9286fc872ed9 (patch)
tree      8ebb3c3af583ac63a8b396c20a1949b75ef47711
parent    e4bd50412043c1ed2816406ba8d2af4f775ee3cf (diff)
[SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator
## What changes were proposed in this pull request?

This PR decouples deserializer expression resolution from `ObjectOperator`, so that we can use deserializer expressions in normal operators. This is needed by #12061 and #12067; I abstracted the logic out and put it in this PR to reduce code changes in the future.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12131 from cloud-fan/separate.
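At a glance, the new pattern looks like this (a minimal sketch assuming spark-catalyst on the classpath; `DeserializerSketch` and `wrapKeyDeserializer` are names invented here for illustration, not code from this patch):

```scala
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

object DeserializerSketch {
  // Instead of each operator publishing (deserializer, attributes) pairs via
  // ObjectOperator.deserializers and having ResolveReferences special-case
  // them, a factory now wraps the encoder's deserializer at construction time.
  // The new ResolveDeserializer rule later resolves the wrapped expression
  // against `groupingAttributes` rather than the child plan's output.
  def wrapKeyDeserializer[K : Encoder](groupingAttributes: Seq[Attribute]): Expression =
    UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes)
}
```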
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala          | 183
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala        |  22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala |   8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala        |  14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala       |  52
5 files changed, 153 insertions(+), 126 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a6e317ebf0..3e0a6d29b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import java.lang.reflect.Modifier
-
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
@@ -87,9 +85,11 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
+ ResolveDeserializer ::
+ ResolveNewInstance ::
+ ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
- ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveGenerate ::
@@ -499,18 +499,9 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
- // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
- // should be resolved by their corresponding attributes instead of children's output.
- case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
- val deserializerToAttributes = o.deserializers.map {
- case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
- }.toMap
-
- o.transformExpressions {
- case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
- resolveDeserializer(expr, attributes)
- }.getOrElse(expr)
- }
+ // Skips plans that contain deserializer expressions, as they should be resolved by another
+ // rule: ResolveDeserializer.
+ case plan if containsDeserializer(plan.expressions) => plan
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -526,38 +517,6 @@ class Analyzer(
}
}
- private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
- exprs.exists { expr =>
- !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
- }
- }
-
- def resolveDeserializer(
- deserializer: Expression,
- attributes: Seq[Attribute]): Expression = {
- val unbound = deserializer transform {
- case b: BoundReference => attributes(b.ordinal)
- }
-
- resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
- case n: NewInstance
- // If this is an inner class of another class, register the outer object in `OuterScopes`.
- // Note that static inner classes (e.g., inner classes within Scala objects) don't need
- // outer pointer registration.
- if n.outerPointer.isEmpty &&
- n.cls.isMemberClass &&
- !Modifier.isStatic(n.cls.getModifiers) =>
- val outer = OuterScopes.getOuterScope(n.cls)
- if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
- "access to the scope that this class was defined in.\n" +
- "Try moving this class out of its parent class.")
- }
- n.copy(outerPointer = Some(outer))
- }
- }
-
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -623,6 +582,10 @@ class Analyzer(
}
}
+ private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
+ exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
+ }
+
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -1475,7 +1438,94 @@ class Analyzer(
Project(projectList, Join(left, right, joinType, newCondition))
}
+ /**
+ * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
+ * to the given input attributes.
+ */
+ object ResolveDeserializer extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+ case p => p transformExpressions {
+ case UnresolvedDeserializer(deserializer, inputAttributes) =>
+ val inputs = if (inputAttributes.isEmpty) {
+ p.children.flatMap(_.output)
+ } else {
+ inputAttributes
+ }
+ val unbound = deserializer transform {
+ case b: BoundReference => inputs(b.ordinal)
+ }
+ resolveExpression(unbound, LocalRelation(inputs), throws = true)
+ }
+ }
+ }
+
+ /**
+ * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
+ * constructed is an inner class.
+ */
+ object ResolveNewInstance extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case n: NewInstance if n.childrenResolved && !n.resolved =>
+ val outer = OuterScopes.getOuterScope(n.cls)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+ n.copy(outerPointer = Some(outer))
+ }
+ }
+ }
+
+ /**
+ * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
+ */
+ object ResolveUpCast extends Rule[LogicalPlan] {
+ private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+ throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+ s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+ "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+ "You can either add an explicit cast to the input data or choose a higher precision " +
+ "type of the field in the target object")
+ }
+
+ private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+ val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+ val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+ toPrecedence > 0 && fromPrecedence > toPrecedence
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case u @ UpCast(child, _, _) if !child.resolved => u
+
+ case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+ case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+ fail(child, to, walkedTypePath)
+ case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+ fail(child, to, walkedTypePath)
+ case (from, to) if illegalNumericPrecedence(from, to) =>
+ fail(child, to, walkedTypePath)
+ case (TimestampType, DateType) =>
+ fail(child, DateType, walkedTypePath)
+ case (StringType, to: NumericType) =>
+ fail(child, to, walkedTypePath)
+ case _ => Cast(child, dataType.asNullable)
+ }
+ }
+ }
+ }
}
/**
@@ -1560,45 +1610,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
/**
- * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
- */
-object ResolveUpCast extends Rule[LogicalPlan] {
- private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
- throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
- s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
- "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
- "You can either add an explicit cast to the input data or choose a higher precision " +
- "type of the field in the target object")
- }
-
- private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
- val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
- val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
- toPrecedence > 0 && fromPrecedence > toPrecedence
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan transformAllExpressions {
- case u @ UpCast(child, _, _) if !child.resolved => u
-
- case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
- case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
- fail(child, to, walkedTypePath)
- case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
- fail(child, to, walkedTypePath)
- case (from, to) if illegalNumericPrecedence(from, to) =>
- fail(child, to, walkedTypePath)
- case (TimestampType, DateType) =>
- fail(child, DateType, walkedTypePath)
- case (StringType, to: NumericType) =>
- fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType.asNullable)
- }
- }
- }
-}
-
-/**
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
* figure out how many windows a time column can map to, we over-estimate the number of windows and
* filter out the rows where the time column is not inside the time window.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e73d367a73..fbbf6302e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override lazy val resolved = false
}
+
+/**
+ * Holds the deserializer expression and the attributes that are available during its resolution.
+ * A deserializer expression is a special kind of expression that is not always resolved against
+ * the children's output, but against given attributes, e.g. the `keyDeserializer` in `MapGroups`
+ * should be resolved against `groupingAttributes` instead of the children's output.
+ *
+ * @param deserializer The unresolved deserializer expression.
+ * @param inputAttributes The input attributes used to resolve the deserializer expression; can be
+ * empty if we want to resolve the deserializer against the children's output.
+ */
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+ extends UnaryExpression with Unevaluable with NonSQLExpression {
+ // The input attributes used to resolve the deserializer expression must all be resolved.
+ require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
+
+ override def child: Expression = deserializer
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+}
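For illustration, both ways of constructing the wrapper (a hedged sketch mirroring the factory calls later in this patch; `WrapperSketch`, `byChildOutput`, and `byAttributes` are names invented here):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

object WrapperSketch {
  // Empty input attributes: ResolveDeserializer falls back to the child
  // plan's output, as the MapPartitions and AppendColumns factories do below.
  def byChildOutput(enc: ExpressionEncoder[_]): Expression =
    UnresolvedDeserializer(enc.deserializer, Nil)

  // Explicit input attributes: the deserializer is resolved against the
  // given attributes instead, e.g. the grouping keys in MapGroups.
  def byAttributes(enc: ExpressionEncoder[_], attrs: Seq[Attribute]): Expression =
    UnresolvedDeserializer(enc.deserializer, attrs)
}
```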
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1c712fde26..56d29cfbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -317,11 +317,11 @@ case class ExpressionEncoder[T](
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
- val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
-
// Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer, check
// analysis, go through optimizer, etc.
- val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
+ val plan = Project(
+ Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
+ LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 07b67a0240..eebd43dae9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.lang.reflect.Modifier
+
import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
@@ -112,7 +114,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
- override def children: Seq[Expression] = arguments.+:(targetObject)
+ override def children: Seq[Expression] = targetObject +: arguments
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -214,6 +216,16 @@ case class NewInstance(
override def children: Seq[Expression] = arguments
+ override lazy val resolved: Boolean = {
+ // If the class to construct is an inner class, we need to get its outer pointer, or this
+ // expression should be regarded as unresolved.
+ // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+ // outer pointer registration.
+ val needOuterPointer =
+ outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
+ childrenResolved && !needOuterPointer
+ }
+
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
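As a hedged aside on the resolved check above (plain JVM reflection, nothing Spark-specific; `OuterPointerSketch` and `needsOuterPointer` are names invented here):

```scala
import java.lang.reflect.Modifier

object OuterPointerSketch {
  // A non-static member class (a true inner class) cannot be constructed
  // reflectively without a reference to an enclosing instance, so NewInstance
  // reports itself unresolved until ResolveNewInstance supplies that pointer.
  // Static nested classes and top-level classes need no outer pointer.
  def needsOuterPointer(cls: Class[_], outerPointer: Option[AnyRef]): Boolean =
    outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
}
```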
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 058fb6bff1..58313c7b72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ObjectType, StructType}
@@ -33,13 +34,6 @@ trait ObjectOperator extends LogicalPlan {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
/**
- * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
- * It must also provide the attributes that are available during the resolution of each
- * deserializer.
- */
- def deserializers: Seq[(Expression, Seq[Attribute])]
-
- /**
* The object type that is produced by the user defined function. Note that the return type here
* is the same whether or not the operator outputs serialized data.
*/
@@ -71,7 +65,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- encoderFor[T].deserializer,
+ UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
encoderFor[U].namedExpressions,
child)
}
@@ -87,9 +81,7 @@ case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
deserializer: Expression,
serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
-}
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
@@ -98,7 +90,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
- encoderFor[T].deserializer,
+ UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
encoderFor[U].namedExpressions,
child)
}
@@ -120,8 +112,6 @@ case class AppendColumns(
override def output: Seq[Attribute] = child.output ++ newColumns
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -133,8 +123,8 @@ object MapGroups {
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[K].deserializer,
- encoderFor[T].deserializer,
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
@@ -158,11 +148,7 @@ case class MapGroups(
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
-}
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
@@ -170,22 +156,24 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
- leftData: Seq[Attribute],
- rightData: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[Key].deserializer,
- encoderFor[Left].deserializer,
- encoderFor[Right].deserializer,
+ // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+ // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+ UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
+ UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
+ UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
- leftData,
- rightData,
+ leftAttr,
+ rightAttr,
left,
right)
}
@@ -206,10 +194,4 @@ case class CoGroup(
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
- right: LogicalPlan) extends BinaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to resolve
- // the `keyDeserializer` based on either of them; here we pick the left one.
- Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
-}
+ right: LogicalPlan) extends BinaryNode with ObjectOperator