diff options
9 files changed, 150 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bd9037ec43..98851cb855 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1619,11 +1619,18 @@ class Analyzer(
       case _ => expr
     }
 
-    /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
     private object AliasedGenerator {
-      def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
-        case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
-        case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
+      /**
+       * Extracts a [[Generator]] expression, any names assigned by aliases to the outputs
+       * and the outer flag. The outer flag is used when joining the generator output.
+       * @param e the [[Expression]]
+       * @return (the [[Generator]], seq of output names, outer flag)
+       */
+      def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
+        case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
+        case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some((g, names, true))
+        case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
+        case MultiAlias(g: Generator, names) if g.resolved => Some((g, names, false))
         case _ => None
       }
     }
@@ -1644,7 +1651,8 @@ class Analyzer(
         var resolvedGenerator: Generate = null
 
         val newProjectList = projectList.flatMap {
-          case AliasedGenerator(generator, names) if generator.childrenResolved =>
+
+          case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
             // It's a sanity check, this should not happen as the previous case will throw
             // exception earlier.
             assert(resolvedGenerator == null, "More than one generator found in SELECT.")
@@ -1653,7 +1661,7 @@ class Analyzer(
               Generate(
                 generator,
                 join = projectList.size > 1, // Only join if there are other expressions in SELECT.
-                outer = false,
+                outer = outer,
                 qualifier = None,
                 generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
                 child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2b214c3c9d..eea3740be8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -163,9 +163,11 @@ object FunctionRegistry {
     expression[Abs]("abs"),
     expression[Coalesce]("coalesce"),
     expression[Explode]("explode"),
+    expressionGeneratorOuter[Explode]("explode_outer"),
     expression[Greatest]("greatest"),
     expression[If]("if"),
     expression[Inline]("inline"),
+    expressionGeneratorOuter[Inline]("inline_outer"),
     expression[IsNaN]("isnan"),
     expression[IfNull]("ifnull"),
     expression[IsNull]("isnull"),
@@ -176,6 +178,7 @@ object FunctionRegistry {
     expression[Nvl]("nvl"),
     expression[Nvl2]("nvl2"),
     expression[PosExplode]("posexplode"),
+    expressionGeneratorOuter[PosExplode]("posexplode_outer"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[Stack]("stack"),
@@ -508,4 +511,13 @@ object FunctionRegistry {
       new ExpressionInfo(clazz.getCanonicalName, name)
     }
   }
+
+  private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String)
+    : (String, (ExpressionInfo, FunctionBuilder)) = {
+    val (_, (info, generatorBuilder)) = expression[T](name)
+    val outerBuilder = (args: Seq[Expression]) => {
+      GeneratorOuter(generatorBuilder(args).asInstanceOf[Generator])
+    }
+    (name, (info, outerBuilder))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 6c38f4998e..1b98c30d37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -204,6 +204,15 @@ case class Stack(children: Seq[Expression]) extends Generator {
   }
 }
 
+case class GeneratorOuter(child: Generator) extends UnaryExpression with Generator {
+  final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
+    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+  override def elementSchema: StructType = child.elementSchema
+}
 /**
  * A base class for [[Explode]] and [[PosExplode]].
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 48f68a6415..3bd314315d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,10 +101,17 @@ case class Generate(
 
   override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
 
-  def qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q =>
-    // prepend the new qualifier to the existed one
-    generatorOutput.map(a => a.withQualifier(Some(q)))
-  }.getOrElse(generatorOutput)
+  def qualifiedGeneratorOutput: Seq[Attribute] = {
+    val qualifiedOutput = qualifier.map { q =>
+      // prepend the new qualifier to the existing one
+      generatorOutput.map(a => a.withQualifier(Some(q)))
+    }.getOrElse(generatorOutput)
+    val nullableOutput = qualifiedOutput.map {
+      // if outer, make all attributes nullable, otherwise keep existing nullability
+      a => a.withNullability(outer || a.nullable)
+    }
+    nullableOutput
+  }
 
   def output: Seq[Attribute] = {
     if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a3f581ff27..60182befd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -166,10 +166,7 @@ class Column(val expr: Expression) extends Logging {
 
     // Leave an unaliased generator with an empty list of names since the analyzer will generate
     // the correct defaults after the nested expression's type has been resolved.
-    case explode: Explode => MultiAlias(explode, Nil)
-    case explode: PosExplode => MultiAlias(explode, Nil)
-
-    case jt: JsonTuple => MultiAlias(jt, Nil)
+    case g: Generator => MultiAlias(g, Nil)
 
     case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 04b16af4ea..b52f5c4d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -162,11 +162,15 @@ case class GenerateExec(
     val index = ctx.freshName("index")
 
     // Add a check if the generate outer flag is true.
-    val checks = optionalCode(outer, data.isNull)
+    val checks = optionalCode(outer, s"($index == -1)")
 
     // Add position
     val position = if (e.position) {
-      Seq(ExprCode("", "false", index))
+      if (outer) {
+        Seq(ExprCode("", s"$index == -1", index))
+      } else {
+        Seq(ExprCode("", "false", index))
+      }
     } else {
       Seq.empty
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cabe1f4563..c86ae5be9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2871,6 +2871,15 @@ object functions {
   def explode(e: Column): Column = withExpr { Explode(e.expr) }
 
   /**
+   * Creates a new row for each element in the given array or map column.
+   * Unlike explode, if the array/map is null or empty then null is produced.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def explode_outer(e: Column): Column = withExpr { GeneratorOuter(Explode(e.expr)) }
+
+  /**
    * Creates a new row for each element with position in the given array or map column.
    *
    * @group collection_funcs
@@ -2879,6 +2888,15 @@ object functions {
   def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
 
   /**
+   * Creates a new row for each element with position in the given array or map column.
+   * Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def posexplode_outer(e: Column): Column = withExpr { GeneratorOuter(PosExplode(e.expr)) }
+
+  /**
    * Extracts json object from a json string based on json path specified, and returns json string
    * of the extracted json object. It will return null if the input json string is invalid.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index f0995ea1d0..b9871afd59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -87,6 +87,13 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row(1) :: Row(2) :: Row(3) :: Nil)
   }
 
+  test("single explode_outer") {
+    val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+    checkAnswer(
+      df.select(explode_outer('intList)),
+      Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+  }
+
   test("single posexplode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     checkAnswer(
@@ -94,6 +101,13 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
   }
 
+  test("single posexplode_outer") {
+    val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+    checkAnswer(
+      df.select(posexplode_outer('intList)),
+      Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil)
+  }
+
   test("explode and other columns") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
 
@@ -110,6 +124,26 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row(1, Seq(1, 2, 3), 3) :: Nil)
   }
 
+  test("explode_outer and other columns") {
+    val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+
+    checkAnswer(
+      df.select($"a", explode_outer('intList)),
+      Row(1, 1) ::
+        Row(1, 2) ::
+        Row(1, 3) ::
+        Row(2, null) ::
+        Nil)
+
+    checkAnswer(
+      df.select($"*", explode_outer('intList)),
+      Row(1, Seq(1, 2, 3), 1) ::
+        Row(1, Seq(1, 2, 3), 2) ::
+        Row(1, Seq(1, 2, 3), 3) ::
+        Row(2, Seq(), null) ::
+        Nil)
+  }
+
   test("aliased explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
 
@@ -122,6 +156,18 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row(6) :: Nil)
   }
 
+  test("aliased explode_outer") {
+    val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+
+    checkAnswer(
+      df.select(explode_outer('intList).as('int)).select('int),
+      Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+    checkAnswer(
+      df.select(explode_outer('intList).as('int)).select(sum('int)),
+      Row(6) :: Nil)
+  }
+
   test("explode on map") {
     val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
 
@@ -130,6 +176,15 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row("a", "b"))
   }
 
+  test("explode_outer on map") {
+    val df = Seq((1, Map("a" -> "b")), (2, Map[String, String]()),
+      (3, Map("c" -> "d"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode_outer('map)),
+      Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil)
+  }
+
   test("explode on map with aliases") {
     val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
 
@@ -138,6 +193,14 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row("a", "b"))
   }
 
+  test("explode_outer on map with aliases") {
+    val df = Seq((3, None), (1, Some(Map("a" -> "b")))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+      Row("a", "b") :: Row(null, null) :: Nil)
+  }
+
   test("self join explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     val exploded = df.select(explode('intList).as('i))
@@ -207,6 +270,19 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       Row(1) :: Row(2) :: Nil)
   }
 
+  test("inline_outer") {
+    val df = Seq((1, "2"), (3, "4"), (5, "6")).toDF("col1", "col2")
+    val df2 = df.select(when('col1 === 1, null).otherwise(array(struct('col1, 'col2))).as("col1"))
+    checkAnswer(
+      df2.selectExpr("inline(col1)"),
+      Row(3, "4") :: Row(5, "6") :: Nil
+    )
+    checkAnswer(
+      df2.selectExpr("inline_outer(col1)"),
+      Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil
+    )
+  }
+
   test("SPARK-14986: Outer lateral view with empty generate expression") {
     checkAnswer(
       sql("select nil from values 1 lateral view outer explode(array()) n as nil"),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index 27ea167b90..df9390aec7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -93,6 +93,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT array(1,2,3)")
     checkSqlGeneration("SELECT coalesce(null, 1, 2)")
     checkSqlGeneration("SELECT explode(array(1,2,3))")
+    checkSqlGeneration("SELECT explode_outer(array())")
     checkSqlGeneration("SELECT greatest(1,null,3)")
     checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
     checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
@@ -102,6 +103,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
     checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
     checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
+    checkSqlGeneration("SELECT posexplode_outer(array())")
+    checkSqlGeneration("SELECT inline_outer(array(struct('a', 1)))")
     checkSqlGeneration("SELECT rand(1)")
     checkSqlGeneration("SELECT randn(3)")
     checkSqlGeneration("SELECT struct(1,2,3)")