diff options
Diffstat (limited to 'sql')
3 files changed, 15 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3fdc6d62bc..891408e310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -141,7 +141,8 @@ class Analyzer( child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne - case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() + case g: GetStructField => Alias(g, g.field.name)() + case g: GetArrayStructFields => Alias(g, g.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u case other => Alias(other, s"_c$i")() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 73cc930c45..5504781edc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -78,12 +78,6 @@ object ExtractValue { } } - def unapply(g: ExtractValue): Option[(Expression, Expression)] = g match { - case o: GetArrayItem => Some((o.child, o.ordinal)) - case o: GetMapValue => Some((o.child, o.key)) - case s: ExtractValueWithStruct => Some((s.child, null)) - } - /** * Find the ordinal of StructField, report error if no desired field or over one * desired fields are found. @@ -104,31 +98,16 @@ object ExtractValue { } /** - * A common interface of all kinds of extract value expressions. - * Note: concrete extract value expressions are created only by `ExtractValue.apply`, - * we don't need to do type check for them. - */ -trait ExtractValue { - self: Expression => -} - -abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { - self: Product => - - def field: StructField - override def toString: String = s"$child.${field.name}" -} - -/** * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends ExtractValueWithStruct { + extends UnaryExpression { override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable + override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow](ordinal) @@ -155,10 +134,11 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, - containsNull: Boolean) extends ExtractValueWithStruct { + containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def nullable: Boolean = child.nullable || containsNull || field.nullable + override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { input.asInstanceOf[Seq[InternalRow]].map { row => @@ -191,8 +171,7 @@ case class GetArrayStructFields( * * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExtractValue { +case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression { override def toString: String = s"$child[$ordinal]" @@ -231,12 +210,11 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `ordinal` in Map `child`. + * Returns the value of key `key` in Map `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ExtractValue { +case class GetMapValue(child: Expression, key: Expression) extends BinaryExpression { override def toString: String = s"$child[$key]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7d41ef9aaf..5d80214abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -275,8 +275,12 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) - case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) |