diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-05-08 11:49:38 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-05-08 11:49:38 -0700 |
commit | 2d05f325dc3c70349bd17ed399897f22d967c687 (patch) | |
tree | 80c39fe01722882e02c9a4e0be9c35a74c082b78 /sql/core/src | |
parent | a1ec08f7edc8d956afcfbb92d10b26b7619486e8 (diff) | |
download | spark-2d05f325dc3c70349bd17ed399897f22d967c687.tar.gz spark-2d05f325dc3c70349bd17ed399897f22d967c687.tar.bz2 spark-2d05f325dc3c70349bd17ed399897f22d967c687.zip |
[SPARK-7133] [SQL] Implement struct, array, and map field accessor
It's the first step: generalize UnresolvedGetField to support all map, struct, and array
TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField` methods to one single API(or should we keep them for compatibility?).
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #5744 from cloud-fan/generalize and squashes the following commits:
715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 19 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 10 |
2 files changed, 24 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8bbe11b412..e6e475bb82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} import org.apache.spark.sql.types._ @@ -68,6 +68,19 @@ class Column(protected[sql] val expr: Expression) extends Logging { override def hashCode: Int = this.expr.hashCode /** + * Extracts a value or values from a complex type. + * The following types of extraction are supported: + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields + * + * @group expr_ops + */ + def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field)) + + /** * Unary minus, i.e. negate the expression. * {{{ * // Scala: select the amount column and negates all values. @@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def getItem(key: Any): Column = GetItem(expr, Literal(key)) + def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key)) /** * An expression that gets a field by name in a [[StructType]]. * * @group expr_ops */ - def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName) + def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName)) /** * An expression that returns a substring. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1515e9b843..d2ca8dccae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest { testData.collect().map { case Row(key: Int, value: String) => Row(key, value, key + 1) }.toSeq) - assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } test("replace column using withColumn") { @@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest { testData.collect().map { case Row(key: Int, value: String) => Row(key, value, key + 1) }.toSeq) - assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) + assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } test("randomSplit") { @@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest { Row(new java.math.BigDecimal(2.0))) TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + + test("SPARK-7133: Implement struct, array, and map field accessor") { + assert(complexData.filter(complexData("a")(0) === 2).count() == 1) + assert(complexData.filter(complexData("m")("1") === 1).count() == 1) + assert(complexData.filter(complexData("s")("key") === 1).count() == 1) + } } |