author    Dongjoon Hyun <dongjoon@apache.org>  2016-07-03 16:59:40 +0800
committer Wenchen Fan <wenchen@databricks.com> 2016-07-03 16:59:40 +0800
commit    54b27c1797fcd32b3f3e9d44e1a149ae396a61e6 (patch)
tree      9d00060aa80e539659bbb8d202bfcddf628c885c
parent    ea990f96930066c36055734d4f17eaf8e496eb3f (diff)
[SPARK-16278][SPARK-16279][SQL] Implement map_keys/map_values SQL functions
## What changes were proposed in this pull request?

This PR adds the `map_keys` and `map_values` SQL functions in order to remove the Hive fallback for them.

## How was this patch tested?

Pass the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13967 from dongjoon-hyun/SPARK-16278.
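For reference, a minimal usage sketch (assuming an existing `SparkSession` named `spark`); the expected results mirror the `@ExpressionDescription` examples in the patch below:

```scala
// Minimal usage sketch; assumes an existing SparkSession named `spark`.
// Expected results mirror the @ExpressionDescription examples in the patch.
spark.sql("SELECT map_keys(map(1, 'a', 2, 'b'))").show()   // [1,2]
spark.sql("SELECT map_values(map(1, 'a', 2, 'b'))").show() // ["a","b"]
```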
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala             |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala      | 48
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala  | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala                            | 16
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala                            |  1
5 files changed, 79 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 26b0c30db4..e7f335f4fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -171,6 +171,8 @@ object FunctionRegistry {
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
+ expression[MapKeys]("map_keys"),
+ expression[MapValues]("map_values"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
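The `expression[T](name)` helper registers a builder that constructs the expression node from the parsed argument list. A simplified sketch of what the `map_keys` entry amounts to (the real registry builds this via reflection and throws `AnalysisException` on an arity mismatch):

```scala
// Simplified sketch, not the actual FunctionRegistry internals: the registered
// builder turns the parsed arguments of map_keys(...) into a MapKeys node.
import org.apache.spark.sql.catalyst.expressions.{Expression, MapKeys}

val mapKeysBuilder: Seq[Expression] => Expression = {
  case Seq(child) => MapKeys(child)
  case args => // the real registry throws AnalysisException here
    sys.error(s"Invalid number of arguments for function map_keys: ${args.size}")
}
```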
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c71cb73d65..2e8ea1107c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -44,6 +44,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
/**
+ * Returns an unordered array containing the keys of the map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.",
+ extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]")
+case class MapKeys(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].keyType)
+
+ override def nullSafeEval(map: Any): Any = {
+ map.asInstanceOf[MapData].keyArray()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();")
+ }
+
+ override def prettyName: String = "map_keys"
+}
+
+/**
+ * Returns an unordered array containing the values of the map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.",
+ extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]")
+case class MapValues(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].valueType)
+
+ override def nullSafeEval(map: Any): Any = {
+ map.asInstanceOf[MapData].valueArray()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();")
+ }
+
+ override def prettyName: String = "map_values"
+}
+
+/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
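Both expressions are thin wrappers over `MapData`: the interpreted path (`nullSafeEval`) and the generated-code path (`doGenCode`) each delegate to `keyArray()`/`valueArray()`. A REPL-style sketch of the interpreted path, mirroring the test suite below:

```scala
// REPL-style sketch of the interpreted path (mirrors the tests below).
import org.apache.spark.sql.catalyst.expressions.{Literal, MapKeys, MapValues}
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

val m = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
MapKeys(m).eval()   // ArrayData holding the keys, e.g. [1, 2]
MapValues(m).eval() // ArrayData holding the values, e.g. ["a", "b"]
```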
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 1aae4678d6..a5f784fdcc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
+ test("MapKeys/MapValues") {
+ val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapKeys(m0), Seq("a", "b"))
+ checkEvaluation(MapValues(m0), Seq("1", "2"))
+ checkEvaluation(MapKeys(m1), Seq())
+ checkEvaluation(MapValues(m1), Seq())
+ checkEvaluation(MapKeys(m2), null)
+ checkEvaluation(MapValues(m2), null)
+ }
+
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 73d77651a0..0f6c49e759 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("map_keys/map_values function") {
+ val df = Seq(
+ (Map[Int, Int](1 -> 100, 2 -> 200), "x"),
+ (Map[Int, Int](), "y"),
+ (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z")
+ ).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("map_keys(a)"),
+ Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3)))
+ )
+ checkAnswer(
+ df.selectExpr("map_values(a)"),
+ Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300)))
+ )
+ }
+
test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),
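Note that the patch registers the functions for SQL only; from the untyped DataFrame API they are reached through `selectExpr` (as in the test above) or `expr`, for example:

```scala
// Sketch: reaching the SQL-registered functions from the DataFrame API
// (`df` as defined in the test above).
import org.apache.spark.sql.functions.expr

val result = df.select(expr("map_keys(a)"), expr("map_values(a)"))
```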
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 1fffadbfca..53990b8e3b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -239,7 +239,6 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
- "map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",