author    Wenchen Fan <cloud0fan@outlook.com>    2015-05-08 11:49:38 -0700
committer Michael Armbrust <michael@databricks.com>    2015-05-08 11:49:38 -0700
commit    2d05f325dc3c70349bd17ed399897f22d967c687 (patch)
tree      80c39fe01722882e02c9a4e0be9c35a74c082b78
parent    a1ec08f7edc8d956afcfbb92d10b26b7619486e8 (diff)
[SPARK-7133] [SQL] Implement struct, array, and map field accessor
It's the first step: generalize UnresolvedGetField to support map, struct, and array.

TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField` methods into one single API (or should we keep them for compatibility?).

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5744 from cloud-fan/generalize and squashes the following commits:

715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
-rw-r--r--  python/pyspark/sql/dataframe.py                                                            |  24
-rw-r--r--  python/pyspark/sql/tests.py                                                                |   7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                  |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala          |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala        |  14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                |   7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala   | 206
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala   | 131
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala        |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala          |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala  |   3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala |  69
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                  |  19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                          |  10
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                             |   4
16 files changed, 327 insertions(+), 191 deletions(-)
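
Before the per-file hunks, a minimal sketch of what this patch enables at the DataFrame level, mirroring the new DataFrameSuite test at the end of the diff. The case classes, column names, and the `sqlContext` below are illustrative assumptions, not part of the patch:

    import org.apache.spark.sql.SQLContext

    // Hypothetical nested schema, modeled on the complexData fixture
    // exercised by the new DataFrameSuite test.
    case class Inner(key: Int)
    case class ComplexData(a: Seq[Int], m: Map[String, Int], s: Inner)

    val sqlContext: SQLContext = ???  // assumed to be available
    val df = sqlContext.createDataFrame(Seq(
      ComplexData(a = Seq(2, 1), m = Map("1" -> 1), s = Inner(key = 1))))

    // One accessor syntax now covers all three complex types:
    df.filter(df("a")(0) === 2).count()      // array ordinal
    df.filter(df("m")("1") === 1).count()    // map key
    df.filter(df("s")("key") === 1).count()  // struct field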
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cee804f5cc..a9697999e8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ class Column(object):
# container operators
__contains__ = _bin_op("contains")
- __getitem__ = _bin_op("getItem")
+ __getitem__ = _bin_op("apply")
# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ class Column(object):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
- +---+
- |r.b|
- +---+
- | b|
- +---+
+ +----+
+ |r[b]|
+ +----+
+ | b|
+ +----+
>>> df.select(df.r.a).show()
- +---+
- |r.a|
- +---+
- | 1|
- +---+
+ +----+
+ |r[a]|
+ +----+
+ | 1|
+ +----+
"""
- return Column(self._jc.getField(name))
+ return self[name]
def __getattr__(self, item):
if item.startswith("__"):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 45dfedce22..7e63f4d646 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -519,6 +519,13 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual("v", df.select(df.d["k"]).first()[0])
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+ def test_field_accessor(self):
+ df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+ self.assertEqual(1, df.select(df.l[0]).first()[0])
+ self.assertEqual(1, df.select(df.r["a"]).first()[0])
+ self.assertEqual("b", df.select(df.r["b"]).first()[0])
+ self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b06bfb2ce8..fc36b9f1f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
- { case base ~ ordinal => GetItem(base, ordinal) }
+ { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
| (expression <~ ".") ~ ident ^^
- { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
+ { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
| cast
| "(" ~> expression <~ ")"
| function
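
Both grammar cases above now funnel into the same unresolved node; a hand-built sketch (for illustration only) of the expressions produced for `a[0]` and `a.b`:

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
    import org.apache.spark.sql.catalyst.expressions.Literal

    // "a[0]" -- ordinal extraction; the concrete extractor is picked at analysis time
    val byOrdinal = UnresolvedExtractValue(UnresolvedAttribute("a"), Literal(0))
    // "a.b" -- field extraction; the field name is now just a string literal
    val byField = UnresolvedExtractValue(UnresolvedAttribute("a"), Literal("b"))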
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bb7913e186..ecbac57ea4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -348,8 +348,8 @@ class Analyzer(
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
- case UnresolvedGetField(child, fieldName) if child.resolved =>
- GetField(child, fieldName, resolver)
+ case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+ ExtractValue(child, fieldExpr, resolver)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eb736ac329..2999c2ef3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -184,7 +184,17 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
-case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+/**
+ * Extracts a value or values from an Expression.
+ *
+ * @param child The expression to extract a value from;
+ *        it can be a Map, an Array, a Struct, or an Array of Structs.
+ * @param extraction The expression describing the extraction;
+ *        it can be a Map key, an Array index, or a Struct field name.
+ */
+case class UnresolvedExtractValue(child: Expression, extraction: Expression)
+ extends UnaryExpression {
+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends Unar
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString: String = s"$child.$fieldName"
+ override def toString: String = s"$child[$extraction]"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index fa6cc7a1a3..4c0d70203c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,8 +100,9 @@ package object dsl {
def isNull: Predicate = IsNull(expr)
def isNotNull: Predicate = IsNotNull(expr)
- def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
- def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
+ def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
+ def getField(fieldName: String): UnresolvedExtractValue =
+ UnresolvedExtractValue(expr, Literal(fieldName))
def cast(to: DataType): Expression = Cast(expr, to)
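
A sketch of the two DSL helpers after this change, relying on catalyst's implicit literal conversions; the symbols `'m` and `'s` are illustrative:

    import org.apache.spark.sql.catalyst.dsl.expressions._

    // Both build the same unresolved node; only the argument wrapping differs.
    val byItem  = 'm.getItem("k")   // UnresolvedExtractValue('m, Literal("k"))
    val byField = 's.getField("a")  // UnresolvedExtractValue('s, Literal("a"))

Deferring to UnresolvedExtractValue means the concrete extractor is chosen during analysis from the child's actual type, rather than at construction time as GetItem/GetField did.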
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
new file mode 100644
index 0000000000..e05926cbfe
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.Map
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types._
+
+object ExtractValue {
+ /**
+ * Returns the resolved `ExtractValue`. One concrete kind of `ExtractValue` is returned,
+ * depending on the types of `child` and `extraction`.
+ *
+ * `child` | `extraction` | concrete `ExtractValue`
+ * ----------------------------------------------------------------
+ * Struct | Literal String | GetStructField
+ * Array[Struct] | Literal String | GetArrayStructFields
+ * Array | Integral type | GetArrayItem
+ * Map | Any type | GetMapValue
+ */
+ def apply(
+ child: Expression,
+ extraction: Expression,
+ resolver: Resolver): ExtractValue = {
+
+ (child.dataType, extraction) match {
+ case (StructType(fields), Literal(fieldName, StringType)) =>
+ val ordinal = findField(fields, fieldName.toString, resolver)
+ GetStructField(child, fields(ordinal), ordinal)
+ case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
+ val ordinal = findField(fields, fieldName.toString, resolver)
+ GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+ case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
+ GetArrayItem(child, extraction)
+ case (_: MapType, _) =>
+ GetMapValue(child, extraction)
+ case (otherType, _) =>
+ val errorMsg = otherType match {
+ case StructType(_) | ArrayType(StructType(_), _) =>
+ s"Field name should be String Literal, but it's $extraction"
+ case _: ArrayType =>
+ s"Array index should be integral type, but it's ${extraction.dataType}"
+ case other =>
+ s"Can't extract value from $child"
+ }
+ throw new AnalysisException(errorMsg)
+ }
+ }
+
+ def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
+ g match {
+ case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
+ case _ => Some((g.child, null))
+ }
+ }
+
+ /**
+ * Finds the ordinal of a StructField, raising an error if the field is not found
+ * or if more than one field matches.
+ */
+ private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
+ val checkField = (f: StructField) => resolver(f.name, fieldName)
+ val ordinal = fields.indexWhere(checkField)
+ if (ordinal == -1) {
+ throw new AnalysisException(
+ s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+ } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+ throw new AnalysisException(
+ s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ } else {
+ ordinal
+ }
+ }
+}
+
+trait ExtractValue extends UnaryExpression {
+ self: Product =>
+
+ type EvaluatedType = Any
+}
+
+/**
+ * Returns the value of the field at `ordinal` in the Struct `child`.
+ */
+case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+ extends ExtractValue {
+
+ override def dataType: DataType = field.dataType
+ override def nullable: Boolean = child.nullable || field.nullable
+ override def foldable: Boolean = child.foldable
+ override def toString: String = s"$child.${field.name}"
+
+ override def eval(input: Row): Any = {
+ val baseValue = child.eval(input).asInstanceOf[Row]
+ if (baseValue == null) null else baseValue(ordinal)
+ }
+}
+
+/**
+ * Returns an array containing the value of the extracted field from each Struct in the Array `child`.
+ */
+case class GetArrayStructFields(
+ child: Expression,
+ field: StructField,
+ ordinal: Int,
+ containsNull: Boolean) extends ExtractValue {
+
+ override def dataType: DataType = ArrayType(field.dataType, containsNull)
+ override def nullable: Boolean = child.nullable
+ override def foldable: Boolean = child.foldable
+ override def toString: String = s"$child.${field.name}"
+
+ override def eval(input: Row): Any = {
+ val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+ if (baseValue == null) null else {
+ baseValue.map { row =>
+ if (row == null) null else row(ordinal)
+ }
+ }
+ }
+}
+
+abstract class ExtractValueWithOrdinal extends ExtractValue {
+ self: Product =>
+
+ def ordinal: Expression
+
+ /** `Null` is returned for invalid ordinals. */
+ override def nullable: Boolean = true
+ override def foldable: Boolean = child.foldable && ordinal.foldable
+ override def toString: String = s"$child[$ordinal]"
+ override def children: Seq[Expression] = child :: ordinal :: Nil
+
+ override def eval(input: Row): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ null
+ } else {
+ val o = ordinal.eval(input)
+ if (o == null) {
+ null
+ } else {
+ evalNotNull(value, o)
+ }
+ }
+ }
+
+ protected def evalNotNull(value: Any, ordinal: Any): Any
+}
+
+/**
+ * Returns the field at `ordinal` in the Array `child`.
+ */
+case class GetArrayItem(child: Expression, ordinal: Expression)
+ extends ExtractValueWithOrdinal {
+
+ override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
+
+ override lazy val resolved = childrenResolved &&
+ child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
+
+ protected def evalNotNull(value: Any, ordinal: Any) = {
+ // TODO: consider using Array[_] for ArrayType child to avoid
+ // boxing of primitives
+ val baseValue = value.asInstanceOf[Seq[_]]
+ val index = ordinal.asInstanceOf[Int]
+ if (index >= baseValue.size || index < 0) {
+ null
+ } else {
+ baseValue(index)
+ }
+ }
+}
+
+/**
+ * Returns the value of key `ordinal` in Map `child`.
+ */
+case class GetMapValue(child: Expression, ordinal: Expression)
+ extends ExtractValueWithOrdinal {
+
+ override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+ override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
+
+ protected def evalNotNull(value: Any, ordinal: Any) = {
+ val baseValue = value.asInstanceOf[Map[Any, _]]
+ baseValue.get(ordinal).orNull
+ }
+}
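
A sketch of the dispatch table implemented by `ExtractValue.apply` above, driven by hand-built attributes. The schemas are illustrative, and the `_ == _`-style equality is the same case-sensitive resolver the test suite below uses:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    val resolver = (a: String, b: String) => a == b  // case-sensitive resolution

    val s   = AttributeReference("s", StructType(StructField("a", IntegerType) :: Nil))()
    val arr = AttributeReference("arr", ArrayType(StringType))()
    val m   = AttributeReference("m", MapType(StringType, IntegerType))()

    ExtractValue(s,   Literal("a"), resolver)  // GetStructField
    ExtractValue(arr, Literal(0),   resolver)  // GetArrayItem (integral ordinal)
    ExtractValue(m,   Literal("k"), resolver)  // GetMapValue (any key type)
    // ExtractValue(s, Literal(0), resolver)   // AnalysisException: field name must be a string literal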
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index fc1f696559..956a2429b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -17,139 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.Map
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.types._
-/**
- * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
- */
-case class GetItem(child: Expression, ordinal: Expression) extends Expression {
- type EvaluatedType = Any
-
- val children: Seq[Expression] = child :: ordinal :: Nil
- /** `Null` is returned for invalid ordinals. */
- override def nullable: Boolean = true
- override def foldable: Boolean = child.foldable && ordinal.foldable
-
- override def dataType: DataType = child.dataType match {
- case ArrayType(dt, _) => dt
- case MapType(_, vt, _) => vt
- }
- override lazy val resolved =
- childrenResolved &&
- (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
-
- override def toString: String = s"$child[$ordinal]"
-
- override def eval(input: Row): Any = {
- val value = child.eval(input)
- if (value == null) {
- null
- } else {
- val key = ordinal.eval(input)
- if (key == null) {
- null
- } else {
- if (child.dataType.isInstanceOf[ArrayType]) {
- // TODO: consider using Array[_] for ArrayType child to avoid
- // boxing of primitives
- val baseValue = value.asInstanceOf[Seq[_]]
- val o = key.asInstanceOf[Int]
- if (o >= baseValue.size || o < 0) {
- null
- } else {
- baseValue(o)
- }
- } else {
- val baseValue = value.asInstanceOf[Map[Any, _]]
- baseValue.get(key).orNull
- }
- }
- }
- }
-}
-
-
-trait GetField extends UnaryExpression {
- self: Product =>
-
- type EvaluatedType = Any
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
-
- def field: StructField
-}
-
-object GetField {
- /**
- * Returns the resolved `GetField`, and report error if no desired field or over one
- * desired fields are found.
- */
- def apply(
- expr: Expression,
- fieldName: String,
- resolver: Resolver): GetField = {
- def findField(fields: Array[StructField]): Int = {
- val checkField = (f: StructField) => resolver(f.name, fieldName)
- val ordinal = fields.indexWhere(checkField)
- if (ordinal == -1) {
- throw new AnalysisException(
- s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
- } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
- throw new AnalysisException(
- s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
- } else {
- ordinal
- }
- }
- expr.dataType match {
- case StructType(fields) =>
- val ordinal = findField(fields)
- StructGetField(expr, fields(ordinal), ordinal)
- case ArrayType(StructType(fields), containsNull) =>
- val ordinal = findField(fields)
- ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
- case otherType =>
- throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
- }
- }
-}
-
-/**
- * Returns the value of fields in the Struct `child`.
- */
-case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
-
- override def dataType: DataType = field.dataType
- override def nullable: Boolean = child.nullable || field.nullable
-
- override def eval(input: Row): Any = {
- val baseValue = child.eval(input).asInstanceOf[Row]
- if (baseValue == null) null else baseValue(ordinal)
- }
-}
-
-/**
- * Returns the array of value of fields in the Array of Struct `child`.
- */
-case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
- extends GetField {
-
- override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable
-
- override def eval(input: Row): Any = {
- val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
- if (baseValue == null) null else {
- baseValue.map { row =>
- if (row == null) null else row(ordinal)
- }
- }
- }
-}
/**
* Returns an Array containing the evaluation of all children expressions.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4a60f53d6..d7b2f203a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -227,10 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
- case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
- case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
- case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+ case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType)
+ case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
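
The two ExtractValue cases above replace the four removed rules because the extractor's `unapply` (defined in ExtractValue.scala) matches every concrete subclass. A sketch of the folding for a null child, with illustrative types:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    val e = GetArrayItem(Literal.create(null, ArrayType(IntegerType)), Literal(0))
    // NullPropagation rewrites e into a typed null literal:
    val folded = Literal.create(null, e.dataType)  // null of IntegerType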
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 4574934d91..cd54d04814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -160,7 +160,7 @@ object PartialAggregation {
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
namedGroupingExpressions
- .get(e.transform { case Alias(g: GetField, _) => g })
+ .get(e.transform { case Alias(g: ExtractValue, _) => g })
.map(_.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ae4620a4e5..dbb12d56f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -209,7 +209,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add GetField("c", GetField("b", a)), and alias
// the final expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver))
+ val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
+ ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
} catch {
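
A trace (names illustrative) of what the foldLeft above builds for a nested reference like `a.b.c` once `a` is resolved:

    // nestedFields = Seq("b", "c")
    // step 1: ExtractValue(a, Literal("b"), resolver)
    // step 2: ExtractValue(ExtractValue(a, Literal("b"), resolver), Literal("c"), resolver)
    // the result is aliased as "c", the last field name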
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 88d36d153c..04fd261d16 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
import org.apache.spark.sql.types._
@@ -880,7 +880,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val row = create_row(
"^Ba*n", // 0
null.asInstanceOf[UTF8String], // 1
- create_row("aa", "bb"), // 2
+ create_row("aa", "bb"), // 2
Map("aa"->"bb"), // 3
Seq("aa", "bb") // 4
)
@@ -891,54 +891,79 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)
- checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+ checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
Literal("aa")), "bb", row)
- checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row)
+ checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
checkEvaluation(
- GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+ GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
+ checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
Literal.create(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+ checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
Literal(1)), "bb", row)
- checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row)
+ checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
checkEvaluation(
- GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+ GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
+ checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
Literal.create(null, IntegerType)), null, row)
- def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = {
+ def getStructField(expr: Expression, fieldName: String): ExtractValue = {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
- StructGetField(expr, field, fields.indexOf(field))
+ GetStructField(expr, field, fields.indexOf(field))
}
}
- def quickResolve(u: UnresolvedGetField): StructGetField = {
- quickBuildGetField(u.child, u.fieldName)
+ def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
+ ExtractValue(u.child, u.extraction, _ == _)
}
- checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
- checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row)
+ checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
+ checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
val typeS_notNullable = StructType(
StructField("a", StringType, nullable = false)
:: StructField("b", StringType, nullable = false) :: Nil
)
- assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
- assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
+    assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
+ assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
=== false)
- assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true)
- assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true)
+ assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
+ assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
- checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
- checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
+ checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
+ checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
}
+ test("error message of ExtractValue") {
+ val structType = StructType(StructField("a", StringType, true) :: Nil)
+ val arrayStructType = ArrayType(structType)
+ val arrayType = ArrayType(StringType)
+ val otherType = StringType
+
+ def checkErrorMessage(
+ childDataType: DataType,
+ fieldDataType: DataType,
+      errorMessage: String): Unit = {
+ val e = intercept[org.apache.spark.sql.AnalysisException] {
+ ExtractValue(
+ Literal.create(null, childDataType),
+ Literal.create(null, fieldDataType),
+ _ == _)
+ }
+      assert(e.getMessage().contains(errorMessage))
+ }
+
+ checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
+ checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
+ checkErrorMessage(arrayType, StringType, "Array index should be integral type")
+ checkErrorMessage(otherType, StringType, "Can't extract value from")
+ }
+
test("arithmetic") {
val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 18f92150b0..6b7d9a85c3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -180,10 +180,10 @@ class ConstantFoldingSuite extends PlanTest {
IsNull(Literal(null)) as 'c1,
IsNotNull(Literal(null)) as 'c2,
- GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
- GetItem(
+ UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
+ UnresolvedExtractValue(
Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
- UnresolvedGetField(
+ UnresolvedExtractValue(
Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bbe11b412..e6e475bb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
import org.apache.spark.sql.types._
@@ -68,6 +68,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
override def hashCode: Int = this.expr.hashCode
/**
+ * Extracts a value or values from a complex type.
+ * The following types of extraction are supported:
+ * - Given an Array, an integer ordinal can be used to retrieve a single value.
+ * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+ * - Given a Struct, a string fieldName can be used to extract that field.
+   * - Given an Array of Structs, a string fieldName can be used to extract the field
+   *   of every struct in that array, returning an Array of fields.
+ *
+ * @group expr_ops
+ */
+ def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field))
+
+ /**
* Unary minus, i.e. negate the expression.
* {{{
* // Scala: select the amount column and negates all values.
@@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*
* @group expr_ops
*/
- def getItem(key: Any): Column = GetItem(expr, Literal(key))
+ def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key))
/**
* An expression that gets a field by name in a [[StructType]].
*
* @group expr_ops
*/
- def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
+ def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName))
/**
* An expression that returns a substring.
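
At the API surface, the three entry points above now share one mechanism; a usage sketch, assuming a DataFrame `df` with a hypothetical array column `arr`, map column `m`, and struct column `s`:

    df("arr")(0)           // apply: array ordinal
    df("m").getItem("k")   // getItem: same UnresolvedExtractValue as apply
    df("s").getField("a")  // getField: field name wrapped in a Literal
    df("s")("a")           // apply subsumes getField for structs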
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1515e9b843..d2ca8dccae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
Row(new java.math.BigDecimal(2.0)))
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
+ test("SPARK-7133: Implement struct, array, and map field accessor") {
+ assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f30b196734..04d40bbb2b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1204,7 +1204,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
- case other => UnresolvedGetField(other, attr)
+ case other => UnresolvedExtractValue(other, Literal(attr))
}
/* Stars (*) */
@@ -1329,7 +1329,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* Complex datatype manipulation */
case Token("[", child :: ordinal :: Nil) =>
- GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+ UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
/* Other functions */
case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>