 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala     | 13
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala    | 17
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala            |  4
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala |  8
 4 files changed, 31 insertions(+), 11 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 75130007b9..e34a478818 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)
- case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+ case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
@@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
// deserialization in condition, and push it down through `SerializeFromObject`.
// e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
// but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
- case f @ TypedFilter(_, _, s: SerializeFromObject)
+ case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
if f.deserializer.dataType == s.inputObjAttr.dataType =>
s.copy(child = f.withObjectProducerChild(s.child))
@@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
*/
object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+ case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
if t1.deserializer.dataType == t2.deserializer.dataType =>
- TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+ TypedFilter(
+ combineFilterFunction(t2.func, t1.func),
+ t1.argumentClass,
+ t1.argumentSchema,
+ t1.deserializer,
+ child)
}
private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
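
The two rules touched above rewrite back-to-back typed operations on a Dataset. A minimal sketch of the query shapes they target, assuming a local SparkSession; the object and variable names here are illustrative and not part of this patch:

    // Sketch: query shapes rewritten by EliminateSerialization and CombineTypedFilters.
    import org.apache.spark.sql.SparkSession

    object TypedRewriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("typed-rewrite-sketch").getOrCreate()
        import spark.implicits._

        val ds = Seq(1, 2, 3, 4, 5).toDS()

        // map followed by filter on the same type: the filter's deserializer matches the
        // map's SerializeFromObject, so the typed filter is pushed below the serialization
        // and the extra deserialization is saved.
        ds.map(_ + 1).filter(_ > 2).explain(true)

        // two adjacent typed filters over the same type: CombineTypedFilters merges them
        // into a single TypedFilter that applies both predicates.
        ds.filter(_ > 1).filter(_ < 5).explain(true)

        spark.stop()
      }
    }
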
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e1890edcbb..fefe5a3953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -155,6 +155,8 @@ object MapElements {
val deserialized = CatalystSerde.deserialize[T](child)
val mapped = MapElements(
func,
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
CatalystSerde.generateObjAttr[U],
deserialized)
CatalystSerde.serialize[U](mapped)
@@ -166,12 +168,19 @@ object MapElements {
*/
case class MapElements(
func: AnyRef,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer
object TypedFilter {
def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
- TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+ TypedFilter(
+ func,
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
+ child)
}
}
@@ -186,6 +195,8 @@ object TypedFilter {
*/
case class TypedFilter(
func: AnyRef,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
deserializer: Expression,
child: LogicalPlan) extends UnaryNode {
@@ -213,6 +224,8 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
+ implicitly[Encoder[T]].clsTag.runtimeClass,
+ implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
@@ -228,6 +241,8 @@ object AppendColumns {
*/
case class AppendColumns(
func: Any => Any,
+ argumentClass: Class[_],
+ argumentSchema: StructType,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode {
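
The new argumentClass and argumentSchema fields are both derived from the implicit Encoder at the point where the typed operator is built, exactly as in the constructors above. A minimal sketch of that derivation using only the public Encoder API; the helper name is hypothetical:

    import org.apache.spark.sql.Encoder
    import org.apache.spark.sql.types.StructType

    // Sketch: what the new fields carry for a given input type T.
    def argumentInfo[T : Encoder]: (Class[_], StructType) = {
      val enc = implicitly[Encoder[T]]
      // runtime class of the deserialized argument object, plus the Catalyst
      // schema its deserializer reads from
      (enc.clsTag.runtimeClass, enc.schema)
    }
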
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index fb08e1228e..4dfec3ec85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -356,9 +356,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
data, objAttr, planLater(child)) :: Nil
- case logical.MapElements(f, objAttr, child) =>
+ case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
- case logical.AppendColumns(f, in, out, child) =>
+ case logical.AppendColumns(f, _, _, in, out, child) =>
execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index c39a78da6f..1dae5f6964 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.expressions.Aggregator
////////////////////////////////////////////////////////////////////////////////////////////////////
-class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
override def zero: Double = 0.0
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
@@ -45,7 +45,7 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
}
-class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
override def zero: Long = 0L
override def reduce(b: Long, a: IN): Long = b + f(a)
override def merge(b1: Long, b2: Long): Long = b1 + b2
@@ -63,7 +63,7 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
}
-class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
override def zero: Long = 0
override def reduce(b: Long, a: IN): Long = {
if (f(a) == null) b else b + 1
@@ -82,7 +82,7 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
}
-class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
override def zero: (Double, Long) = (0.0, 0L)
override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
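
Promoting f to a val in the typed aggregators makes the extraction function readable after construction, presumably so other components can inspect it. A small sketch of what that enables, assuming a spark-shell/REPL session; the tuple type is illustrative:

    import org.apache.spark.sql.execution.aggregate.TypedSumDouble

    // With `f` now a val, the extractor passed at construction time can be read back.
    val sum = new TypedSumDouble[(String, Double)](_._2)
    val extractor: ((String, Double)) => Double = sum.f
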