aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2015-04-16 17:33:57 -0700
committer Michael Armbrust <michael@databricks.com> 2015-04-16 17:33:57 -0700
commit 6183b5e2caedd074073d0f6cb6609a634e2f5194 (patch)
tree 073a82a2ff33eea0a5d8faae03e313ee749198b6 /sql
parent 5fe43433529346788e8c343d338a5b7dc169cf58 (diff)
download spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.gz
spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.bz2
spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.zip
[SPARK-6911] [SQL] improve accessor for nested types
Support accessing columns by index in Python:

```
>>> df[df[0] > 3].collect()
[Row(age=5, name=u'Bob')]
```

Access items in ArrayType or MapType:

```
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
>>> df.select(df.l[0], df.d["key"]).show()
```

Access a field in StructType:

```
>>> df.select(df.r.getField("b")).show()
>>> df.select(df.r.a).show()
```

Author: Davies Liu <davies@databricks.com>

Closes #5513 from davies/access and squashes the following commits:

e04d5a0 [Davies Liu] Update run-tests-jenkins
7ada9eb [Davies Liu] update timeout
d125ac4 [Davies Liu] check column name, improve scala tests
6b62540 [Davies Liu] fix test
db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access
6c32e79 [Davies Liu] add scala tests
11f1df3 [Davies Liu] improve accessor for nested types
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala 7
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/TestData.scala 9
3 files changed, 14 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3cd7adf8ca..edb229c059 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
/**
- * An expression that gets an item at position `ordinal` out of an array.
+ * An expression that gets an item at position `ordinal` out of an array,
+ * or gets a value by key `key` in a [[MapType]].
*
* @group expr_ops
*/
- def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
+ def getItem(key: Any): Column = GetItem(expr, Literal(key))
/**
- * An expression that gets a field by name in a [[StructField]].
+ * An expression that gets a field by name in a [[StructType]].
*
* @group expr_ops
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b26e22f622..34b2cb054a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest {
TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
}
+ test("access complex data") {
+ assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
+ assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
+ assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
+ }
+
test("table scan") {
checkAnswer(
testData,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 637f59b2e6..225b51bd73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test._
import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test._
case class TestData(key: Int, value: String)
@@ -199,11 +198,11 @@ object TestData {
Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")
- case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
+ case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
- ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
- :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
+ ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
+ :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")
}