author    Sean Zhong <seanzhong@databricks.com>  2016-08-09 08:36:50 +0800
committer Wenchen Fan <wenchen@databricks.com>   2016-08-09 08:36:50 +0800
commit    bca43cd63503eb5287151c5d9ca6ccd8cd13fbc8 (patch)
tree      922dc9ae9533127715ab4a415f0b10518991901a /sql/catalyst/src/main/scala/org/apache
parent    df10658831f4e5f9756a5732673ad12904b5d05c (diff)
[SPARK-16898][SQL] Adds argument type information for typed logical plans like MapElements, TypedFilter, and AppendColumns
## What changes were proposed in this pull request?

This PR adds argument type information (the argument's runtime class and its Catalyst schema) to typed logical plans like MapElements, TypedFilter, and AppendColumns, so that customized optimizer rules can make use of it.

## How was this patch tested?

Existing tests.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14494 from clockfly/add_more_info_for_typed_operator.
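For illustration, here is a minimal sketch of the kind of customized optimizer rule this enables, reading the new fields off a `TypedFilter`. The rule name `InspectTypedFilters` and its logging are hypothetical; only the `argumentClass` and `argumentSchema` fields come from this patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule sketch: with this patch, a custom rule can inspect the
// JVM class and Catalyst schema of a typed filter's argument directly,
// instead of re-deriving them from the deserializer expression.
object InspectTypedFilters extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f: TypedFilter =>
      logInfo(s"Typed filter over ${f.argumentClass.getName}, " +
        s"schema: ${f.argumentSchema.simpleString}")
      f  // no rewrite; this rule only observes
  }
}
```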
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala   13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  17
2 files changed, 25 insertions(+), 5 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 75130007b9..e34a478818 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)
- case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+ case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
@@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
// deserialization in condition, and push it down through `SerializeFromObject`.
// e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
// but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
- case f @ TypedFilter(_, _, s: SerializeFromObject)
+ case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
if f.deserializer.dataType == s.inputObjAttr.dataType =>
s.copy(child = f.withObjectProducerChild(s.child))
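A minimal Dataset-level sketch of the two shapes named in the comment above, assuming a SparkSession `spark` in scope; the case classes `A` and `B` are hypothetical and share a schema but not a JVM class:

```scala
import spark.implicits._

case class A(id: Long)
case class B(id: Long)

val ds = Seq(A(1), A(2), A(3)).toDS()

// Same object type on both sides of the boundary: the serialize/deserialize
// pair between map and filter can be eliminated.
ds.map(a => A(a.id + 1)).filter(_.id > 2)

// as[B] re-encodes to a different object type, so the guard
// `f.deserializer.dataType == s.inputObjAttr.dataType` fails and the
// extra deserialization stays.
ds.map(a => A(a.id + 1)).as[B].filter(_.id > 2)
```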
@@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
*/
object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+ case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
if t1.deserializer.dataType == t2.deserializer.dataType =>
- TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+ TypedFilter(
+ combineFilterFunction(t2.func, t1.func),
+ t1.argumentClass,
+ t1.argumentSchema,
+ t1.deserializer,
+ child)
}
private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
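A plain-Scala sketch of the fusion `combineFilterFunction` performs, assuming both predicates take the same argument type; note that the inner (earlier) filter `t2` runs before the outer `t1`. The helper name `fuse` is illustrative, not from the Spark source:

```scala
// Illustrative fusion of two typed filter predicates into one pass.
def fuse[T](inner: T => Boolean, outer: T => Boolean): T => Boolean =
  x => inner(x) && outer(x)

val keepPositiveEvens = fuse[Int](_ > 0, _ % 2 == 0)
assert(keepPositiveEvens(4))   // passes both predicates
assert(!keepPositiveEvens(-2)) // rejected by the inner predicate first
```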
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e1890edcbb..fefe5a3953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -155,6 +155,8 @@ object MapElements {
val deserialized = CatalystSerde.deserialize[T](child)
val mapped = MapElements(
func,
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
CatalystSerde.generateObjAttr[U],
deserialized)
CatalystSerde.serialize[U](mapped)
@@ -166,12 +168,19 @@ object MapElements {
*/
case class MapElements(
func: AnyRef,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer
object TypedFilter {
def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
- TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+ TypedFilter(
+ func,
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
+ child)
}
}
@@ -186,6 +195,8 @@ object TypedFilter {
*/
case class TypedFilter(
func: AnyRef,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
deserializer: Expression,
child: LogicalPlan) extends UnaryNode {
@@ -213,6 +224,8 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
@@ -228,6 +241,8 @@ object AppendColumns {
*/
case class AppendColumns(
func: Any => Any,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode {
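For reference, a small standalone sketch of how these factory methods derive the two new fields from an implicit Encoder; `describeArgument` is a hypothetical helper, not part of the patch:

```scala
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types.StructType

// Hypothetical helper mirroring the pattern used in the factories above.
def describeArgument[T : Encoder]: (Class[_], StructType) = {
  val enc = implicitly[Encoder[T]]
  (enc.clsTag.runtimeClass, enc.schema)
}

// With spark.implicits._ in scope, e.g.:
//   describeArgument[(Int, String)]
// yields (classOf[Tuple2[_, _]], StructType(_1: Int, _2: String)).
```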