From 2d05f325dc3c70349bd17ed399897f22d967c687 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 8 May 2015 11:49:38 -0700
Subject: [SPARK-7133] [SQL] Implement struct, array, and map field accessor

This is the first step: generalize UnresolvedGetField to support map, struct,
and array access.

TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem`
and `getField` methods into a single API (or should we keep them for
compatibility?).

Author: Wenchen Fan

Closes #5744 from cloud-fan/generalize and squashes the following commits:

715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
---
 python/pyspark/sql/dataframe.py                    |  24 +--
 python/pyspark/sql/tests.py                        |   7 +
 .../org/apache/spark/sql/catalyst/SqlParser.scala  |   4 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   4 +-
 .../spark/sql/catalyst/analysis/unresolved.scala   |  14 +-
 .../apache/spark/sql/catalyst/dsl/package.scala    |   7 +-
 .../sql/catalyst/expressions/ExtractValue.scala    | 206 +++++++++++++++++++++
 .../sql/catalyst/expressions/complexTypes.scala    | 131 -------------
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   6 +-
 .../spark/sql/catalyst/planning/patterns.scala     |   2 +-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |   3 +-
 .../expressions/ExpressionEvaluationSuite.scala    |  69 ++++---
 .../catalyst/optimizer/ConstantFoldingSuite.scala  |   8 +-
 .../main/scala/org/apache/spark/sql/Column.scala   |  19 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  10 +-
 .../scala/org/apache/spark/sql/hive/HiveQl.scala   |   4 +-
 16 files changed, 327 insertions(+), 191 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cee804f5cc..a9697999e8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ class Column(object):
 
     # container operators
     __contains__ = _bin_op("contains")
-    __getitem__ = _bin_op("getItem")
+    __getitem__ = _bin_op("apply")
 
     # bitwise operators
     bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ class Column(object):
         >>> from pyspark.sql import Row
         >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
         >>> df.select(df.r.getField("b")).show()
-        +---+
-        |r.b|
-        +---+
-        |  b|
-        +---+
+        +----+
+        |r[b]|
+        +----+
+        |   b|
+        +----+
         >>> df.select(df.r.a).show()
-        +---+
-        |r.a|
-        +---+
-        |  1|
-        +---+
+        +----+
+        |r[a]|
+        +----+
+        |   1|
+        +----+
         """
-        return Column(self._jc.getField(name))
+        return self[name]
 
     def __getattr__(self, item):
         if item.startswith("__"):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 45dfedce22..7e63f4d646 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -519,6 +519,13 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual("v", df.select(df.d["k"]).first()[0])
         self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
 
+    def test_field_accessor(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.r["a"]).first()[0])
+        self.assertEqual("b", df.select(df.r["b"]).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
     def test_infer_long_type(self):
         longrow = [Row(f1='a', f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b06bfb2ce8..fc36b9f1f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val primary: PackratParser[Expression] =
     ( literal
     | expression ~ ("[" ~> expression <~ "]") ^^
-      { case base ~ ordinal => GetItem(base, ordinal) }
+      { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
     | (expression <~ ".") ~ ident ^^
-      { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
+      { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
     | cast
     | "(" ~> expression <~ ")"
     | function
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bb7913e186..ecbac57ea4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -348,8 +348,8 @@ class Analyzer(
           withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
         logDebug(s"Resolving $u to $result")
         result
-      case UnresolvedGetField(child, fieldName) if child.resolved =>
-        GetField(child, fieldName, resolver)
+      case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+        ExtractValue(child, fieldExpr, resolver)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eb736ac329..2999c2ef3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -184,7 +184,17 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
 }
 
-case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+/**
+ * Extracts a value or values from an Expression
+ *
+ * @param child The expression to extract value from,
+ *              can be Map, Array, Struct or array of Structs.
+ * @param extraction The expression to describe the extraction,
+ *                   can be key of Map, index of Array, field name of Struct.
+ */
+case class UnresolvedExtractValue(child: Expression, extraction: Expression)
+  extends UnaryExpression {
+
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends Unar
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 
-  override def toString: String = s"$child.$fieldName"
+  override def toString: String = s"$child[$extraction]"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index fa6cc7a1a3..4c0d70203c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,8 +100,9 @@ package object dsl {
     def isNull: Predicate = IsNull(expr)
     def isNotNull: Predicate = IsNotNull(expr)
 
-    def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
-    def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
+    def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
+    def getField(fieldName: String): UnresolvedExtractValue =
+      UnresolvedExtractValue(expr, Literal(fieldName))
 
     def cast(to: DataType): Expression = Cast(expr, to)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
new file mode 100644
index 0000000000..e05926cbfe
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.Map
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types._
+
+object ExtractValue {
+  /**
+   * Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
+   * depending on the types of `child` and `extraction`.
+   *
+   *   `child`       |    `extraction`    |    concrete `ExtractValue`
+   *   ----------------------------------------------------------------
+   *   Struct        |   Literal String   |        GetStructField
+   *   Array[Struct] |   Literal String   |     GetArrayStructFields
+   *   Array         |   Integral type    |         GetArrayItem
+   *   Map           |      Any type      |         GetMapValue
+   */
+  def apply(
+      child: Expression,
+      extraction: Expression,
+      resolver: Resolver): ExtractValue = {
+
+    (child.dataType, extraction) match {
+      case (StructType(fields), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        GetStructField(child, fields(ordinal), ordinal)
+      case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+      case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
+        GetArrayItem(child, extraction)
+      case (_: MapType, _) =>
+        GetMapValue(child, extraction)
+      case (otherType, _) =>
+        val errorMsg = otherType match {
+          case StructType(_) | ArrayType(StructType(_), _) =>
+            s"Field name should be String Literal, but it's $extraction"
+          case _: ArrayType =>
+            s"Array index should be integral type, but it's ${extraction.dataType}"
+          case other =>
+            s"Can't extract value from $child"
+        }
+        throw new AnalysisException(errorMsg)
+    }
+  }
+
+  def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
+    g match {
+      case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
+      case _ => Some((g.child, null))
+    }
+  }
+
+  /**
+   * Finds the ordinal of the requested StructField, reporting an error if no matching field
+   * or more than one matching field is found.
+   */
+  private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
+    val checkField = (f: StructField) => resolver(f.name, fieldName)
+    val ordinal = fields.indexWhere(checkField)
+    if (ordinal == -1) {
+      throw new AnalysisException(
+        s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+    } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+      throw new AnalysisException(
+        s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+    } else {
+      ordinal
+    }
+  }
+}
+
+trait ExtractValue extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+}
+
+/**
+ * Returns the value of the field at `ordinal` in the Struct `child`.
+ */
+case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+  extends ExtractValue {
+
+  override def dataType: DataType = field.dataType
+  override def nullable: Boolean = child.nullable || field.nullable
+  override def foldable: Boolean = child.foldable
+  override def toString: String = s"$child.${field.name}"
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Row]
+    if (baseValue == null) null else baseValue(ordinal)
+  }
+}
+
+/**
+ * Returns an array containing the value of the field in each Struct of the Array `child`.
+ */
+case class GetArrayStructFields(
+    child: Expression,
+    field: StructField,
+    ordinal: Int,
+    containsNull: Boolean) extends ExtractValue {
+
+  override def dataType: DataType = ArrayType(field.dataType, containsNull)
+  override def nullable: Boolean = child.nullable
+  override def foldable: Boolean = child.foldable
+  override def toString: String = s"$child.${field.name}"
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
+}
+
+abstract class ExtractValueWithOrdinal extends ExtractValue {
+  self: Product =>
+
+  def ordinal: Expression
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+  override def foldable: Boolean = child.foldable && ordinal.foldable
+  override def toString: String = s"$child[$ordinal]"
+  override def children: Seq[Expression] = child :: ordinal :: Nil
+
+  override def eval(input: Row): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      val o = ordinal.eval(input)
+      if (o == null) {
+        null
+      } else {
+        evalNotNull(value, o)
+      }
+    }
+  }
+
+  protected def evalNotNull(value: Any, ordinal: Any): Any
+}
+
+/**
+ * Returns the field at `ordinal` in the Array `child`
+ */
+case class GetArrayItem(child: Expression, ordinal: Expression)
+  extends ExtractValueWithOrdinal {
+
+  override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
+
+  override lazy val resolved = childrenResolved &&
+    child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    // TODO: consider using Array[_] for ArrayType child to avoid
+    // boxing of primitives
+    val baseValue = value.asInstanceOf[Seq[_]]
+    val index = ordinal.asInstanceOf[Int]
+    if (index >= baseValue.size || index < 0) {
+      null
+    } else {
+      baseValue(index)
+    }
+  }
+}
+
+/**
+ * Returns the value of key `ordinal` in Map `child`
+ */
+case class GetMapValue(child: Expression, ordinal: Expression)
+  extends ExtractValueWithOrdinal {
+
+  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    val baseValue = value.asInstanceOf[Map[Any, _]]
+    baseValue.get(ordinal).orNull
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index fc1f696559..956a2429b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -17,139 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.Map
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.types._
 
-/**
- * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
- */
-case class GetItem(child: Expression, ordinal: Expression) extends Expression {
-  type EvaluatedType = Any
-
-  val children: Seq[Expression] = child :: ordinal :: Nil
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-  override def foldable: Boolean = child.foldable && ordinal.foldable
-
-  override def dataType: DataType = child.dataType match {
-    case ArrayType(dt, _) => dt
-    case MapType(_, vt, _) => vt
-  }
-  override lazy val resolved =
-    childrenResolved &&
-    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
-
-  override def toString: String = s"$child[$ordinal]"
-
-  override def eval(input: Row): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      val key = ordinal.eval(input)
-      if (key == null) {
-        null
-      } else {
-        if (child.dataType.isInstanceOf[ArrayType]) {
-          // TODO: consider using Array[_] for ArrayType child to avoid
-          // boxing of primitives
-          val baseValue = value.asInstanceOf[Seq[_]]
-          val o = key.asInstanceOf[Int]
-          if (o >= baseValue.size || o < 0) {
-            null
-          } else {
-            baseValue(o)
-          }
-        } else {
-          val baseValue = value.asInstanceOf[Map[Any, _]]
-          baseValue.get(key).orNull
-        }
-      }
-    }
-  }
-}
-
-
-trait GetField extends UnaryExpression {
-  self: Product =>
-
-  type EvaluatedType = Any
-  override def foldable: Boolean = child.foldable
-  override def toString: String = s"$child.${field.name}"
-
-  def field: StructField
-}
-
-object GetField {
-  /**
-   * Returns the resolved `GetField`, and report error if no desired field or over one
-   * desired fields are found.
-   */
-  def apply(
-      expr: Expression,
-      fieldName: String,
-      resolver: Resolver): GetField = {
-    def findField(fields: Array[StructField]): Int = {
-      val checkField = (f: StructField) => resolver(f.name, fieldName)
-      val ordinal = fields.indexWhere(checkField)
-      if (ordinal == -1) {
-        throw new AnalysisException(
-          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
-        throw new AnalysisException(
-          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
-      } else {
-        ordinal
-      }
-    }
-    expr.dataType match {
-      case StructType(fields) =>
-        val ordinal = findField(fields)
-        StructGetField(expr, fields(ordinal), ordinal)
-      case ArrayType(StructType(fields), containsNull) =>
-        val ordinal = findField(fields)
-        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
-      case otherType =>
-        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
-    }
-  }
-}
-
-/**
- * Returns the value of fields in the Struct `child`.
- */
-case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
-
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Row]
-    if (baseValue == null) null else baseValue(ordinal)
-  }
-}
-
-/**
- * Returns the array of value of fields in the Array of Struct `child`.
- */
-case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
-  extends GetField {
-
-  override def dataType: DataType = ArrayType(field.dataType, containsNull)
-  override def nullable: Boolean = child.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
-    if (baseValue == null) null else {
-      baseValue.map { row =>
-        if (row == null) null else row(ordinal)
-      }
-    }
-  }
-}
 
 /**
  * Returns an Array containing the evaluation of all children expressions.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4a60f53d6..d7b2f203a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -227,10 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
-      case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+      case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType)
+      case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 4574934d91..cd54d04814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -160,7 +160,7 @@ object PartialAggregation {
             // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
             // (Should we just turn `GetField` into a `NamedExpression`?)
             namedGroupingExpressions
-              .get(e.transform { case Alias(g: GetField, _) => g })
+              .get(e.transform { case Alias(g: ExtractValue, _) => g })
               .map(_.toAttribute)
              .getOrElse(e)
           }).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ae4620a4e5..dbb12d56f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -209,7 +209,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
         // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
         // Then this will add GetField("c", GetField("b", a)), and alias
         // the final expression as "c".
-        val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver))
+        val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
+          ExtractValue(expr, Literal(fieldName), resolver))
         val aliasName = nestedFields.last
         Some(Alias(fieldExprs, aliasName)())
       } catch {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 88d36d153c..04fd261d16 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
 import org.scalatest.Matchers._
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.mathfuncs._
 import org.apache.spark.sql.types._
@@ -880,7 +880,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     val row = create_row(
       "^Ba*n", // 0
       null.asInstanceOf[UTF8String], // 1
-      create_row("aa", "bb"),  // 2
+      create_row("aa", "bb"), // 2
       Map("aa"->"bb"), // 3
       Seq("aa", "bb") // 4
     )
@@ -891,54 +891,79 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     val typeMap = MapType(StringType, StringType)
     val typeArray = ArrayType(StringType)
 
-    checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
       Literal("aa")), "bb", row)
-    checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row)
+    checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
     checkEvaluation(
-      GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
-    checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+      GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
       Literal.create(null, StringType)), null, row)
 
-    checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
       Literal(1)), "bb", row)
-    checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row)
+    checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
     checkEvaluation(
-      GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
-    checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+      GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
      Literal.create(null, IntegerType)), null, row)
 
-    def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = {
+    def getStructField(expr: Expression, fieldName: String): ExtractValue = {
       expr.dataType match {
         case StructType(fields) =>
           val field = fields.find(_.name == fieldName).get
-          StructGetField(expr, field, fields.indexOf(field))
+          GetStructField(expr, field, fields.indexOf(field))
       }
     }
 
-    def quickResolve(u: UnresolvedGetField): StructGetField = {
-      quickBuildGetField(u.child, u.fieldName)
+    def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
+      ExtractValue(u.child, u.extraction, _ == _)
     }
 
-    checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
-    checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row)
+    checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
+    checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
 
     val typeS_notNullable = StructType(
       StructField("a", StringType, nullable = false)
         :: StructField("b", StringType, nullable = false) :: Nil
     )
 
-    assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
-    assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
+    assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
+    assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
       === false)
 
-    assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true)
-    assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true)
+    assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
+    assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
 
-    checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
-    checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
+    checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
+    checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
     checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
   }
 
+  test("error message of ExtractValue") {
+    val structType = StructType(StructField("a", StringType, true) :: Nil)
+    val arrayStructType = ArrayType(structType)
+    val arrayType = ArrayType(StringType)
+    val otherType = StringType
+
+    def checkErrorMessage(
+        childDataType: DataType,
+        fieldDataType: DataType,
+        errorMessage: String): Unit = {
+      val e = intercept[org.apache.spark.sql.AnalysisException] {
+        ExtractValue(
+          Literal.create(null, childDataType),
+          Literal.create(null, fieldDataType),
+          _ == _)
+      }
+      assert(e.getMessage().contains(errorMessage))
+    }
+
+    checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
+    checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
+    checkErrorMessage(arrayType, StringType, "Array index should be integral type")
+    checkErrorMessage(otherType, StringType, "Can't extract value from")
+  }
+
   test("arithmetic") {
     val row = create_row(1, 2, 3, null)
     val c1 = 'a.int.at(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 18f92150b0..6b7d9a85c3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -180,10 +180,10 @@ class ConstantFoldingSuite extends PlanTest {
         IsNull(Literal(null)) as 'c1,
         IsNotNull(Literal(null)) as 'c2,
 
-        GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
-        GetItem(
+        UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
+        UnresolvedExtractValue(
           Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
-        UnresolvedGetField(
+        UnresolvedExtractValue(
           Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
           "a") as 'c5,
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bbe11b412..e6e475bb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
 import org.apache.spark.sql.types._
 
 
@@ -67,6 +67,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
 
   override def hashCode: Int = this.expr.hashCode
 
+  /**
+   * Extracts a value or values from a complex type.
+   * The following types of extraction are supported:
+   * - Given an Array, an integer ordinal can be used to retrieve a single value.
+   * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+   * - Given a Struct, a string fieldName can be used to extract that field.
+   * - Given an Array of Structs, a string fieldName can be used to extract the field
+   *   of every struct in that array, returning an Array of fields.
+   *
+   * @group expr_ops
+   */
+  def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field))
+
   /**
    * Unary minus, i.e. negate the expression.
    * {{{
@@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *
    * @group expr_ops
    */
-  def getItem(key: Any): Column = GetItem(expr, Literal(key))
+  def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key))
 
   /**
    * An expression that gets a field by name in a [[StructType]].
    *
    * @group expr_ops
    */
-  def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
+  def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName))
 
   /**
    * An expression that returns a substring.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1515e9b843..d2ca8dccae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
       testData.collect().map { case Row(key: Int, value: String) =>
         Row(key, value, key + 1)
       }.toSeq)
-    assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
   test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
       testData.collect().map { case Row(key: Int, value: String) =>
         Row(key, value, key + 1)
       }.toSeq)
-    assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+    assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
   test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
       Row(new java.math.BigDecimal(2.0)))
     TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
   }
+
+  test("SPARK-7133: Implement struct, array, and map field accessor") {
+    assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+    assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+    assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f30b196734..04d40bbb2b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1204,7 +1204,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       nodeToExpr(qualifier) match {
         case UnresolvedAttribute(qualifierName) =>
           UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
-        case other => UnresolvedGetField(other, attr)
+        case other => UnresolvedExtractValue(other, Literal(attr))
       }
 
     /* Stars (*) */
@@ -1329,7 +1329,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>
-      GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+      UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
 
     /* Other functions */
     case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
-- 
cgit v1.2.3
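Usage note (not part of the patch; text after the signature line above is
ignored by `git am`). A minimal sketch of the accessor this change adds, for a
Spark 1.4-era spark-shell: it assumes `sc: SparkContext` and a `sqlContext`
are in scope, and the case classes and column names below are illustrative,
modeled on the complexData fixture exercised in DataFrameSuite.

  import sqlContext.implicits._

  // Hypothetical nested schema: an array, a map, and a struct column.
  case class Inner(key: Int)
  case class Rec(a: Seq[Int], m: Map[String, Int], s: Inner)

  val df = sc.parallelize(Seq(Rec(Seq(2, 3), Map("1" -> 1), Inner(1)))).toDF()

  // Column.apply builds an UnresolvedExtractValue; the Analyzer then resolves
  // it to GetArrayItem, GetMapValue, or GetStructField from the child's type.
  df.filter(df("a")(0) === 2).count()      // array: integer ordinal
  df.filter(df("m")("1") === 1).count()    // map: key lookup
  df.filter(df("s")("key") === 1).count()  // struct: field name

In Python the same accessor is reachable through `__getitem__`, e.g.
`df.l[0]`, `df.r["a"]`, `df.d["k"]`, as exercised by test_field_accessor.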