diff options
5 files changed, 124 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7f335f4fb..021bec7f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -165,6 +165,7 @@ object FunctionRegistry { expression[Explode]("explode"), expression[Greatest]("greatest"), expression[If]("if"), + expression[Inline]("inline"), expression[IsNaN]("isnan"), expression[IfNull]("ifnull"), expression[IsNull]("isnull"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4e91cc5aec..99b97c8ea2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -195,3 +195,38 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20") // scalastyle:on line.size.limit case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) + +/** + * Explodes an array of structs into a table. + */ +@ExpressionDescription( + usage = "_FUNC_(a) - Explodes an array of structs into a table.", + extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]") +case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { + + override def children: Seq[Expression] = child :: Nil + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(et, _) if et.isInstanceOf[StructType] => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should be array of struct type, not ${child.dataType}") + } + + override def elementSchema: StructType = child.dataType match { + case ArrayType(et : StructType, _) => et + } + + private lazy val numFields = elementSchema.fields.length + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) { + Nil + } else { + for (i <- 0 until inputArray.numElements()) + yield inputArray.getStruct(i, numFields) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala index 2aba84141b..e79f89b497 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala @@ -19,53 +19,48 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = { - assert(actual.eval(null).toSeq === expected) + private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = { + assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected) } - private final val int_array = Seq(1, 2, 3) - private final val str_array = Seq("a", "b", "c") + private final val empty_array = CreateArray(Seq.empty) + private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_))) + private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_))) test("explode") { - val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3)) - val str_correct_answer = Seq( - Seq(UTF8String.fromString("a")), - Seq(UTF8String.fromString("b")), - Seq(UTF8String.fromString("c"))) + val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3)) + val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c")) - checkTuple( - Explode(CreateArray(Seq.empty)), - Seq.empty) + checkTuple(Explode(empty_array), Seq.empty) + checkTuple(Explode(int_array), int_correct_answer) + checkTuple(Explode(str_array), str_correct_answer) + } - checkTuple( - Explode(CreateArray(int_array.map(Literal(_)))), - int_correct_answer.map(InternalRow.fromSeq(_))) + test("posexplode") { + val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3)) + val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c")) - checkTuple( - Explode(CreateArray(str_array.map(Literal(_)))), - str_correct_answer.map(InternalRow.fromSeq(_))) + checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty) + checkTuple(PosExplode(int_array), int_correct_answer) + checkTuple(PosExplode(str_array), str_correct_answer) } - test("posexplode") { - val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3)) - val str_correct_answer = Seq( - Seq(0, UTF8String.fromString("a")), - Seq(1, UTF8String.fromString("b")), - Seq(2, UTF8String.fromString("c"))) + test("inline") { + val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c")) checkTuple( - PosExplode(CreateArray(Seq.empty)), + Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))), Seq.empty) checkTuple( - PosExplode(CreateArray(int_array.map(Literal(_)))), - int_correct_answer.map(InternalRow.fromSeq(_))) - - checkTuple( - PosExplode(CreateArray(str_array.map(Literal(_)))), - str_correct_answer.map(InternalRow.fromSeq(_))) + Inline(CreateArray(Seq( + CreateStruct(Seq(Literal(0), Literal("a"))), + CreateStruct(Seq(Literal(1), Literal("b"))), + CreateStruct(Seq(Literal(2), Literal("c"))) + ))), + correct_answer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 1f0ef34ec1..d8a0aa4d52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -89,4 +89,64 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), Row(3) :: Nil) } + + test("inline raises exception on array of null type") { + val m = intercept[AnalysisException] { + spark.range(2).selectExpr("inline(array())") + }.getMessage + assert(m.contains("data type mismatch")) + } + + test("inline with empty table") { + checkAnswer( + spark.range(0).selectExpr("inline(array(struct(10, 100)))"), + Nil) + } + + test("inline on literal") { + checkAnswer( + spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"), + Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: + Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil) + } + + test("inline on column") { + val df = Seq((1, 2)).toDF("a", "b") + + checkAnswer( + df.selectExpr("inline(array(struct(a), struct(a)))"), + Row(1) :: Row(1) :: Nil) + + checkAnswer( + df.selectExpr("inline(array(struct(a, b), struct(a, b)))"), + Row(1, 2) :: Row(1, 2) :: Nil) + + // Spark think [struct<a:int>, struct<b:int>] is heterogeneous due to name difference. + val m = intercept[AnalysisException] { + df.selectExpr("inline(array(struct(a), struct(b)))") + }.getMessage + assert(m.contains("data type mismatch")) + + checkAnswer( + df.selectExpr("inline(array(struct(a), named_struct('a', b)))"), + Row(1) :: Row(2) :: Nil) + + // Spark think [struct<a:int>, struct<col1:int>] is heterogeneous due to name difference. + val m2 = intercept[AnalysisException] { + df.selectExpr("inline(array(struct(a), struct(2)))") + }.getMessage + assert(m2.contains("data type mismatch")) + + checkAnswer( + df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + df.selectExpr("struct(a)").selectExpr("inline(array(*))"), + Row(1) :: Nil) + + checkAnswer( + df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"), + Row(1) :: Row(2) :: Nil) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 53990b8e3b..18b8dafe64 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -241,9 +241,6 @@ private[sql] class HiveSessionCatalog( "hash", "java_method", "histogram_numeric", "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", - "xpath_number", "xpath_short", "xpath_string", - - // table generating function - "inline" + "xpath_number", "xpath_short", "xpath_string" ) } |