about | summary | refs | log | tree | commit | diff
path: root/sql/core
diff options
context:
space:
mode:
author    Dongjoon Hyun <dongjoon@apache.org>    2016-06-30 12:03:54 -0700
committer Reynold Xin <rxin@databricks.com>    2016-06-30 12:03:54 -0700
commit   46395db80e3304e3f3a1ebdc8aadb8f2819b48b4 (patch)
tree     88b3c5cc5e5241f0e2b687445a29f88a2aca2c6b /sql/core
parent   fdf9f94f8c8861a00cd8415073f842b857c397f7 (diff)
download spark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.tar.gz
         spark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.tar.bz2
         spark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.zip
[SPARK-16289][SQL] Implement posexplode table generating function
## What changes were proposed in this pull request? This PR implements `posexplode` table generating function. Currently, master branch raises the following exception for `map` argument. It's different from Hive. **Before** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show org.apache.spark.sql.AnalysisException: No handler for Hive UDF ... posexplode() takes an array as a parameter; line 1 pos 7 ``` **After** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show +---+---+-----+ |pos|key|value| +---+---+-----+ | 0| a| 1| | 1| b| 2| +---+---+-----+ ``` For `array` argument, `after` is the same with `before`. ``` scala> sql("select posexplode(array(1, 2, 3))").show +---+---+ |pos|col| +---+---+ | 0| 1| | 1| 2| | 2| 3| +---+---+ ``` ## How was this patch tested? Pass the Jenkins tests with newly added testcases. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13971 from dongjoon-hyun/SPARK-16289.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala   60
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala  92
4 files changed, 101 insertions, 60 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9f35107e5b..a46d1949e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
+ case explode: PosExplode => MultiAlias(explode, Nil)
case jt: JsonTuple => MultiAlias(jt, Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e8bd489be3..c8782df146 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2722,6 +2722,14 @@ object functions {
def explode(e: Column): Column = withExpr { Explode(e.expr) }
/**
+ * Creates a new row for each element with position in the given array or map column.
+ *
+ * @group collection_funcs
+ * @since 2.1.0
+ */
+ def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
+
+ /**
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a66c83dea0..a170fae577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
}
- test("single explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- checkAnswer(
- df.select(explode('intList)),
- Row(1) :: Row(2) :: Row(3) :: Nil)
- }
-
- test("explode and other columns") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select($"a", explode('intList)),
- Row(1, 1) ::
- Row(1, 2) ::
- Row(1, 3) :: Nil)
-
- checkAnswer(
- df.select($"*", explode('intList)),
- Row(1, Seq(1, 2, 3), 1) ::
- Row(1, Seq(1, 2, 3), 2) ::
- Row(1, Seq(1, 2, 3), 3) :: Nil)
- }
-
- test("aliased explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select(explode('intList).as('int)).select('int),
- Row(1) :: Row(2) :: Row(3) :: Nil)
-
- checkAnswer(
- df.select(explode('intList).as('int)).select(sum('int)),
- Row(6) :: Nil)
- }
-
- test("explode on map") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map)),
- Row("a", "b"))
- }
-
- test("explode on map with aliases") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
- Row("a", "b"))
- }
-
- test("self join explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- val exploded = df.select(explode('intList).as('i))
-
- checkAnswer(
- exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
- Row(3) :: Nil)
- }
-
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
new file mode 100644
index 0000000000..1f0ef34ec1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("single explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("single posexplode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(posexplode('intList)),
+ Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
+ }
+
+ test("explode and other columns") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) :: Nil)
+
+ checkAnswer(
+ df.select($"*", explode('intList)),
+ Row(1, Seq(1, 2, 3), 1) ::
+ Row(1, Seq(1, 2, 3), 2) ::
+ Row(1, Seq(1, 2, 3), 3) :: Nil)
+ }
+
+ test("aliased explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
+ test("explode on map") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map)),
+ Row("a", "b"))
+ }
+
+ test("explode on map with aliases") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+ Row("a", "b"))
+ }
+
+ test("self join explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ val exploded = df.select(explode('intList).as('i))
+
+ checkAnswer(
+ exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+ Row(3) :: Nil)
+ }
+}