aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-06-30 12:03:54 -0700
committerReynold Xin <rxin@databricks.com>2016-06-30 12:03:54 -0700
commit46395db80e3304e3f3a1ebdc8aadb8f2819b48b4 (patch)
tree88b3c5cc5e5241f0e2b687445a29f88a2aca2c6b /sql/catalyst
parentfdf9f94f8c8861a00cd8415073f842b857c397f7 (diff)
downloadspark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.tar.gz
spark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.tar.bz2
spark-46395db80e3304e3f3a1ebdc8aadb8f2819b48b4.zip
[SPARK-16289][SQL] Implement posexplode table generating function
## What changes were proposed in this pull request? This PR implements `posexplode` table generating function. Currently, master branch raises the following exception for `map` argument. It's different from Hive. **Before** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show org.apache.spark.sql.AnalysisException: No handler for Hive UDF ... posexplode() takes an array as a parameter; line 1 pos 7 ``` **After** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show +---+---+-----+ |pos|key|value| +---+---+-----+ | 0| a| 1| | 1| b| 2| +---+---+-----+ ``` For `array` argument, `after` is the same with `before`. ``` scala> sql("select posexplode(array(1, 2, 3))").show +---+---+ |pos|col| +---+---+ | 0| 1| | 1| 2| | 2| 3| +---+---+ ``` ## How was this patch tested? Pass the Jenkins tests with newly added testcases. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13971 from dongjoon-hyun/SPARK-16289.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala66
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala71
4 files changed, 130 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f9227a8ae..3fbdb2ab57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -176,6 +176,7 @@ object FunctionRegistry {
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
+ expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 12c35644e5..4e91cc5aec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,13 +94,10 @@ case class UserDefinedGenerator(
}
/**
- * Given an input array produces a sequence of rows for each value in the array.
+ * A base class for Explode and PosExplode
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
-// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+abstract class ExplodeBase(child: Expression, position: Boolean)
+ extends UnaryExpression with Generator with CodegenFallback with Serializable {
override def children: Seq[Expression] = child :: Nil
@@ -115,9 +112,26 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
- case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
+ case ArrayType(et, containsNull) =>
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("col", et, containsNull)
+ } else {
+ new StructType()
+ .add("col", et, containsNull)
+ }
case MapType(kt, vt, valueContainsNull) =>
- new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ } else {
+ new StructType()
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ }
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +143,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
} else {
val rows = new Array[InternalRow](inputArray.numElements())
inputArray.foreach(et, (i, e) => {
- rows(i) = InternalRow(e)
+ rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
})
rows
}
@@ -141,7 +155,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
val rows = new Array[InternalRow](inputMap.numElements())
var i = 0
inputMap.foreach(kt, vt, (k, v) => {
- rows(i) = InternalRow(k, v)
+ rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
i += 1
})
rows
@@ -149,3 +163,35 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
}
}
}
+
+/**
+ * Given an input array produces a sequence of rows for each value in the array.
+ *
+ * {{{
+ * SELECT explode(array(10,20)) ->
+ * 10
+ * 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
+// scalastyle:on line.size.limit
+case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+
+/**
+ * Given an input array produces a sequence of rows for each position and value in the array.
+ *
+ * {{{
+ * SELECT posexplode(array(10,20)) ->
+ * 0 10
+ * 1 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
+// scalastyle:on line.size.limit
+case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 54436ea9a4..76e42d9afa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
+ assertError(PosExplode('intField),
+ "input to function explode should be array or map type")
}
test("check types for CreateNamedStruct") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
new file mode 100644
index 0000000000..2aba84141b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+
+class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+ private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
+ assert(actual.eval(null).toSeq === expected)
+ }
+
+ private final val int_array = Seq(1, 2, 3)
+ private final val str_array = Seq("a", "b", "c")
+
+ test("explode") {
+ val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
+ val str_correct_answer = Seq(
+ Seq(UTF8String.fromString("a")),
+ Seq(UTF8String.fromString("b")),
+ Seq(UTF8String.fromString("c")))
+
+ checkTuple(
+ Explode(CreateArray(Seq.empty)),
+ Seq.empty)
+
+ checkTuple(
+ Explode(CreateArray(int_array.map(Literal(_)))),
+ int_correct_answer.map(InternalRow.fromSeq(_)))
+
+ checkTuple(
+ Explode(CreateArray(str_array.map(Literal(_)))),
+ str_correct_answer.map(InternalRow.fromSeq(_)))
+ }
+
+ test("posexplode") {
+ val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
+ val str_correct_answer = Seq(
+ Seq(0, UTF8String.fromString("a")),
+ Seq(1, UTF8String.fromString("b")),
+ Seq(2, UTF8String.fromString("c")))
+
+ checkTuple(
+ PosExplode(CreateArray(Seq.empty)),
+ Seq.empty)
+
+ checkTuple(
+ PosExplode(CreateArray(int_array.map(Literal(_)))),
+ int_correct_answer.map(InternalRow.fromSeq(_)))
+
+ checkTuple(
+ PosExplode(CreateArray(str_array.map(Literal(_)))),
+ str_correct_answer.map(InternalRow.fromSeq(_)))
+ }
+}