aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-07-04 01:57:45 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-04 01:57:45 +0800
commit88134e736829f5f93a82879c08cb191f175ff8af (patch)
treeb6795f4b148b595c2c1aedc2d61fd9f0bd04c130
parent54b27c1797fcd32b3f3e9d44e1a149ae396a61e6 (diff)
downloadspark-88134e736829f5f93a82879c08cb191f175ff8af.tar.gz
spark-88134e736829f5f93a82879c08cb191f175ff8af.tar.bz2
spark-88134e736829f5f93a82879c08cb191f175ff8af.zip
[SPARK-16288][SQL] Implement inline table generating function
## What changes were proposed in this pull request?

This PR implements the `inline` table generating function.

## How was this patch tested?

Pass the Jenkins tests with a new testcase.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13976 from dongjoon-hyun/SPARK-16288.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala35
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala59
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala60
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala5
5 files changed, 124 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e7f335f4fb..021bec7f5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -165,6 +165,7 @@ object FunctionRegistry {
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
+ expression[Inline]("inline"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 4e91cc5aec..99b97c8ea2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -195,3 +195,38 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+
+/**
+ * Explodes an array of structs into a table.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Explodes an array of structs into a table.",
+ extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]")
+case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should be array of struct type, not ${child.dataType}")
+ }
+
+ override def elementSchema: StructType = child.dataType match {
+ case ArrayType(et : StructType, _) => et
+ }
+
+ private lazy val numFields = elementSchema.fields.length
+
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+ val inputArray = child.eval(input).asInstanceOf[ArrayData]
+ if (inputArray == null) {
+ Nil
+ } else {
+ for (i <- 0 until inputArray.numElements())
+ yield inputArray.getStruct(i, numFields)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
index 2aba84141b..e79f89b497 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -19,53 +19,48 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
- private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
- assert(actual.eval(null).toSeq === expected)
+ private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
+ assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
}
- private final val int_array = Seq(1, 2, 3)
- private final val str_array = Seq("a", "b", "c")
+ private final val empty_array = CreateArray(Seq.empty)
+ private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
+ private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))
test("explode") {
- val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
- val str_correct_answer = Seq(
- Seq(UTF8String.fromString("a")),
- Seq(UTF8String.fromString("b")),
- Seq(UTF8String.fromString("c")))
+ val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
+ val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))
- checkTuple(
- Explode(CreateArray(Seq.empty)),
- Seq.empty)
+ checkTuple(Explode(empty_array), Seq.empty)
+ checkTuple(Explode(int_array), int_correct_answer)
+ checkTuple(Explode(str_array), str_correct_answer)
+ }
- checkTuple(
- Explode(CreateArray(int_array.map(Literal(_)))),
- int_correct_answer.map(InternalRow.fromSeq(_)))
+ test("posexplode") {
+ val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
+ val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
- checkTuple(
- Explode(CreateArray(str_array.map(Literal(_)))),
- str_correct_answer.map(InternalRow.fromSeq(_)))
+ checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
+ checkTuple(PosExplode(int_array), int_correct_answer)
+ checkTuple(PosExplode(str_array), str_correct_answer)
}
- test("posexplode") {
- val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
- val str_correct_answer = Seq(
- Seq(0, UTF8String.fromString("a")),
- Seq(1, UTF8String.fromString("b")),
- Seq(2, UTF8String.fromString("c")))
+ test("inline") {
+ val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
checkTuple(
- PosExplode(CreateArray(Seq.empty)),
+ Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
Seq.empty)
checkTuple(
- PosExplode(CreateArray(int_array.map(Literal(_)))),
- int_correct_answer.map(InternalRow.fromSeq(_)))
-
- checkTuple(
- PosExplode(CreateArray(str_array.map(Literal(_)))),
- str_correct_answer.map(InternalRow.fromSeq(_)))
+ Inline(CreateArray(Seq(
+ CreateStruct(Seq(Literal(0), Literal("a"))),
+ CreateStruct(Seq(Literal(1), Literal("b"))),
+ CreateStruct(Seq(Literal(2), Literal("c")))
+ ))),
+ correct_answer)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 1f0ef34ec1..d8a0aa4d52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -89,4 +89,64 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
Row(3) :: Nil)
}
+
+ test("inline raises exception on array of null type") {
+ val m = intercept[AnalysisException] {
+ spark.range(2).selectExpr("inline(array())")
+ }.getMessage
+ assert(m.contains("data type mismatch"))
+ }
+
+ test("inline with empty table") {
+ checkAnswer(
+ spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
+ Nil)
+ }
+
+ test("inline on literal") {
+ checkAnswer(
+ spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
+ Row(10, 100) :: Row(20, 200) :: Row(30, 300) ::
+ Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil)
+ }
+
+ test("inline on column") {
+ val df = Seq((1, 2)).toDF("a", "b")
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), struct(a)))"),
+ Row(1) :: Row(1) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
+ Row(1, 2) :: Row(1, 2) :: Nil)
+
+ // Spark thinks [struct<a:int>, struct<b:int>] is heterogeneous due to the name difference.
+ val m = intercept[AnalysisException] {
+ df.selectExpr("inline(array(struct(a), struct(b)))")
+ }.getMessage
+ assert(m.contains("data type mismatch"))
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), named_struct('a', b)))"),
+ Row(1) :: Row(2) :: Nil)
+
+ // Spark thinks [struct<a:int>, struct<col1:int>] is heterogeneous due to the name difference.
+ val m2 = intercept[AnalysisException] {
+ df.selectExpr("inline(array(struct(a), struct(2)))")
+ }.getMessage
+ assert(m2.contains("data type mismatch"))
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"),
+ Row(1) :: Row(2) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("struct(a)").selectExpr("inline(array(*))"),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
+ Row(1) :: Row(2) :: Nil)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 53990b8e3b..18b8dafe64 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -241,9 +241,6 @@ private[sql] class HiveSessionCatalog(
"hash", "java_method", "histogram_numeric",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
- "xpath_number", "xpath_short", "xpath_string",
-
- // table generating function
- "inline"
+ "xpath_number", "xpath_short", "xpath_string"
)
}