aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala10
2 files changed, 24 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bbe11b412..e6e475bb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
import org.apache.spark.sql.types._
@@ -68,6 +68,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
override def hashCode: Int = this.expr.hashCode
/**
+ * Extracts a value or values from a complex type.
+ * The following types of extraction are supported:
+ * - Given an Array, an integer ordinal can be used to retrieve a single value.
+ * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+ * - Given a Struct, a string fieldName can be used to extract that field.
   * - Given an Array of Structs, a string fieldName can be used to extract that field
   *   of every struct in that array, and return an Array of fields
+ *
+ * @group expr_ops
+ */
+ def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field))
+
+ /**
* Unary minus, i.e. negate the expression.
* {{{
* // Scala: select the amount column and negates all values.
@@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*
* @group expr_ops
*/
- def getItem(key: Any): Column = GetItem(expr, Literal(key))
+ def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key))
/**
* An expression that gets a field by name in a [[StructType]].
*
* @group expr_ops
*/
- def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
+ def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName))
/**
* An expression that returns a substring.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1515e9b843..d2ca8dccae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
- assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+ assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
Row(new java.math.BigDecimal(2.0)))
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
+ test("SPARK-7133: Implement struct, array, and map field accessor") {
+ assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+ }
}