author    Wenchen Fan <cloud0fan@outlook.com>        2015-05-08 11:49:38 -0700
committer Michael Armbrust <michael@databricks.com>  2015-05-08 11:49:38 -0700
commit    2d05f325dc3c70349bd17ed399897f22d967c687 (patch)
tree      80c39fe01722882e02c9a4e0be9c35a74c082b78 /sql/core/src
parent    a1ec08f7edc8d956afcfbb92d10b26b7619486e8 (diff)
[SPARK-7133] [SQL] Implement struct, array, and map field accessor
It's the first step: generalize UnresolvedGetField to support map, struct, and array types.

TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField` methods into one single API (or should we keep them for compatibility?).

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5744 from cloud-fan/generalize and squashes the following commits:

715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array
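As a quick sketch of the new surface (the DataFrame and column names here are illustrative, not part of this patch), all three complex types can now be indexed with the same apply syntax:

    df("a")(0)        // array column: an integer ordinal selects one element
    df("m")("key")    // map column: a key of the map's key type selects one value
    df("s")("field")  // struct column: a string field name selects that field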
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala          19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  10
2 files changed, 24 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bbe11b412..e6e475bb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
import org.apache.spark.sql.types._
@@ -68,6 +68,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
override def hashCode: Int = this.expr.hashCode
/**
+ * Extracts a value or values from a complex type.
+ * The following types of extraction are supported:
+ * - Given an Array, an integer ordinal can be used to retrieve a single value.
+ * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+ * - Given a Struct, a string fieldName can be used to extract that field.
+ * - Given an Array of Structs, a string fieldName can be used to extract a field
+ * of every struct in that array, and return an Array of those fields.
+ *
+ * @group expr_ops
+ */
+ def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field))
+
+ /**
* Unary minus, i.e. negate the expression.
* {{{
* // Scala: select the amount column and negates all values.
@@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*
* @group expr_ops
*/
- def getItem(key: Any): Column = GetItem(expr, Literal(key))
+ def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key))
/**
* An expression that gets a field by name in a [[StructType]].
*
* @group expr_ops
*/
- def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
+ def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName))
/**
* An expression that returns a substring.
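Note that after this change getItem, getField, and the new apply method all construct the same unresolved expression, so the following calls are interchangeable (df and the column name are hypothetical):

    // Each of these builds UnresolvedExtractValue(expr, Literal("key")):
    df("m").getItem("key")
    df("m").getField("key")
    df("m")("key")

The choice of concrete extraction (array ordinal, map lookup, struct field) is deferred to the analyzer, which resolves UnresolvedExtractValue based on the child expression's data type.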
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1515e9b843..d2ca8dccae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
Row(new java.math.BigDecimal(2.0)))
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
+ test("SPARK-7133: Implement struct, array, and map field accessor") {
+ assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+ }
}
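The complexData fixture used in the new test lives in the shared test data and is not part of this diff; a minimal stand-in that would satisfy the same assertions might look like this (the case classes and the sqlContext name are assumptions):

    case class S(key: Int)
    case class Complex(a: Seq[Int], m: Map[String, Int], s: S)

    val complexData = sqlContext.createDataFrame(Seq(
      Complex(Seq(2, 1), Map("1" -> 1), S(1)),
      Complex(Seq(1, 2), Map("2" -> 2), S(2))))

    complexData.filter(complexData("a")(0) === 2).count()    // 1: first element of array a
    complexData.filter(complexData("m")("1") === 1).count()  // 1: value for map key "1"
    complexData.filter(complexData("s")("key") === 1).count() // 1: struct field "key"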