 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                   | 20
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala           | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala              |  9
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 15
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                           |  5
 sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala                           |  8
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                        | 18
 sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala                           | 76
 sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala                    |  3
 9 files changed, 150 insertions(+), 16 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bd9037ec43..98851cb855 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1619,11 +1619,18 @@ class Analyzer(
case _ => expr
}
- /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
- def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
- case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
- case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
+ /**
+ * Extracts a [[Generator]] expression, any names assigned by aliases to the outputs
+ * and the outer flag. The outer flag is used when joining the generator output.
+ * @param e the [[Expression]]
+ * @return (the [[Generator]], seq of output names, outer flag)
+ */
+ def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
+ case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
+ case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some(g, names, true)
+ case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
+ case MultiAlias(g: Generator, names) if g.resolved => Some(g, names, false)
case _ => None
}
}
@@ -1644,7 +1651,8 @@ class Analyzer(
var resolvedGenerator: Generate = null
val newProjectList = projectList.flatMap {
- case AliasedGenerator(generator, names) if generator.childrenResolved =>
+
+ case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
// It's a sanity check, this should not happen as the previous case will throw
// exception earlier.
assert(resolvedGenerator == null, "More than one generator found in SELECT.")
@@ -1653,7 +1661,7 @@ class Analyzer(
Generate(
generator,
join = projectList.size > 1, // Only join if there are other expressions in SELECT.
- outer = false,
+ outer = outer,
qualifier = None,
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
child)
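
For illustration, a minimal standalone sketch of the extended extractor pattern used above, with simplified stand-in case classes rather than Spark's actual catalyst types (Alias, MultiAlias, GeneratorOuter): the outer wrapper is recognized first and reported as a boolean flag alongside the generator and its output names.

    // Simplified model: an Outer wrapper marks "outer" semantics, and the extractor
    // returns the inner generator, its output names, and the outer flag.
    sealed trait Expr
    case class Gen(name: String) extends Expr
    case class Outer(child: Gen) extends Expr
    case class Aliased(child: Expr, names: Seq[String]) extends Expr

    object AliasedGen {
      def unapply(e: Expr): Option[(Gen, Seq[String], Boolean)] = e match {
        case Aliased(Outer(g), names) => Some((g, names, true))   // outer-wrapped generator
        case Aliased(g: Gen, names)   => Some((g, names, false))  // plain generator
        case _ => None
      }
    }

    // AliasedGen.unapply(Aliased(Outer(Gen("explode")), Seq("col")))
    //   == Some((Gen("explode"), Seq("col"), true))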
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2b214c3c9d..eea3740be8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -163,9 +163,11 @@ object FunctionRegistry {
expression[Abs]("abs"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
+ expressionGeneratorOuter[Explode]("explode_outer"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[Inline]("inline"),
+ expressionGeneratorOuter[Inline]("inline_outer"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
@@ -176,6 +178,7 @@ object FunctionRegistry {
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
+ expressionGeneratorOuter[PosExplode]("posexplode_outer"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
@@ -508,4 +511,13 @@ object FunctionRegistry {
new ExpressionInfo(clazz.getCanonicalName, name)
}
}
+
+ private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String)
+ : (String, (ExpressionInfo, FunctionBuilder)) = {
+ val (_, (info, generatorBuilder)) = expression[T](name)
+ val outerBuilder = (args: Seq[Expression]) => {
+ GeneratorOuter(generatorBuilder(args).asInstanceOf[Generator])
+ }
+ (name, (info, outerBuilder))
+ }
}
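
With these registrations in place, the outer variants are callable from SQL by name. A minimal usage sketch, assuming a local SparkSession on Spark 2.2 or later:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("outer-generators").getOrCreate()
    // explode(array()) yields no rows; explode_outer(array()) yields a single NULL row.
    spark.sql("SELECT explode_outer(array())").show()
    spark.sql("SELECT posexplode_outer(array())").show()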
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 6c38f4998e..1b98c30d37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -204,6 +204,15 @@ case class Stack(children: Seq[Expression]) extends Generator {
}
}
+case class GeneratorOuter(child: Generator) extends UnaryExpression with Generator {
+ final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override def elementSchema: StructType = child.elementSchema
+}
/**
* A base class for [[Explode]] and [[PosExplode]].
*/
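
GeneratorOuter is purely a marker: it exposes the child's element schema and is stripped by the analyzer (see the AliasedGenerator cases above), so evaluating it directly is intentionally unsupported. A hedged construction sketch, assuming spark-catalyst on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{Explode, GeneratorOuter, Literal}
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val gen = GeneratorOuter(Explode(Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))))
    gen.elementSchema   // same as the wrapped Explode's schema
    // gen.eval()       // would throw UnsupportedOperationException, per the definition above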
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 48f68a6415..3bd314315d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,10 +101,17 @@ case class Generate(
override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
- def qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q =>
- // prepend the new qualifier to the existed one
- generatorOutput.map(a => a.withQualifier(Some(q)))
- }.getOrElse(generatorOutput)
+ def qualifiedGeneratorOutput: Seq[Attribute] = {
+ val qualifiedOutput = qualifier.map { q =>
+ // prepend the new qualifier to the existing one
+ generatorOutput.map(a => a.withQualifier(Some(q)))
+ }.getOrElse(generatorOutput)
+ val nullableOutput = qualifiedOutput.map {
+ // if outer, make all attributes nullable, otherwise keep existing nullability
+ a => a.withNullability(outer || a.nullable)
+ }
+ nullableOutput
+ }
def output: Seq[Attribute] = {
if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
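
The nullability widening above is visible in the resulting schema: with an outer generator the generated columns are always nullable, because rows whose input array/map is empty or null produce NULLs. A sketch of the expected behaviour, spark-shell style, using the explode_outer function added later in this patch:

    import spark.implicits._
    import org.apache.spark.sql.functions.explode_outer

    val df = Seq((1, Seq(1, 2, 3)), (2, Seq.empty[Int])).toDF("a", "intList")
    df.select(explode_outer($"intList")).printSchema()
    // root
    //  |-- col: integer (nullable = true)   <- nullable even though the elements themselves are not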
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a3f581ff27..60182befd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -166,10 +166,7 @@ class Column(val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
- case explode: Explode => MultiAlias(explode, Nil)
- case explode: PosExplode => MultiAlias(explode, Nil)
-
- case jt: JsonTuple => MultiAlias(jt, Nil)
+ case g: Generator => MultiAlias(g, Nil)
case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias))
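
Collapsing the per-class cases into a single `case g: Generator` means any generator expression, including the new GeneratorOuter wrappers, gets the analyzer's default multi-column names when used unaliased in a select. An illustrative check (column names are the analyzer defaults, not supplied by the caller):

    import spark.implicits._
    import org.apache.spark.sql.functions.{explode_outer, posexplode_outer}

    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
    df.select(explode_outer($"map")).columns      // expected: Array("key", "value")
    df.select(posexplode_outer($"map")).columns   // expected: Array("pos", "key", "value")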
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 04b16af4ea..b52f5c4d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -162,11 +162,15 @@ case class GenerateExec(
val index = ctx.freshName("index")
// Add a check if the generate outer flag is true.
- val checks = optionalCode(outer, data.isNull)
+ val checks = optionalCode(outer, s"($index == -1)")
// Add position
val position = if (e.position) {
- Seq(ExprCode("", "false", index))
+ if (outer) {
+ Seq(ExprCode("", s"$index == -1", index))
+ } else {
+ Seq(ExprCode("", "false", index))
+ }
} else {
Seq.empty
}
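
Conceptually, the codegen change uses index == -1 as a sentinel for the extra "outer" pass: when the input collection is empty (or null) and outer is set, one iteration is emitted with index -1, which makes every generated column NULL and, for posexplode, the position NULL as well. A sketch of the iteration logic only, not the actual generated Java:

    def generateIndices(numElements: Int, outer: Boolean): Seq[Int] =
      if (numElements == 0 && outer) Seq(-1)   // single pass producing NULL outputs
      else 0 until numElements

    generateIndices(3, outer = true)    // Seq(0, 1, 2)
    generateIndices(0, outer = true)    // Seq(-1)
    generateIndices(0, outer = false)   // Seq()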
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cabe1f4563..c86ae5be9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2871,6 +2871,15 @@ object functions {
def explode(e: Column): Column = withExpr { Explode(e.expr) }
/**
+ * Creates a new row for each element in the given array or map column.
+ * Unlike explode, if the array/map is null or empty then null is produced.
+ *
+ * @group collection_funcs
+ * @since 2.2.0
+ */
+ def explode_outer(e: Column): Column = withExpr { GeneratorOuter(Explode(e.expr)) }
+
+ /**
* Creates a new row for each element with position in the given array or map column.
*
* @group collection_funcs
@@ -2879,6 +2888,15 @@ object functions {
def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
/**
+ * Creates a new row for each element with position in the given array or map column.
+ * Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+ *
+ * @group collection_funcs
+ * @since 2.2.0
+ */
+ def posexplode_outer(e: Column): Column = withExpr { GeneratorOuter(PosExplode(e.expr)) }
+
+ /**
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.
*
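
DataFrame-side usage of the two new functions, mirroring the behaviour described in the scaladoc above (illustrative output, as exercised by the tests below):

    import spark.implicits._
    import org.apache.spark.sql.functions.explode_outer

    val df = Seq((1, Seq(1, 2)), (2, Seq.empty[Int])).toDF("a", "xs")
    df.select($"a", explode_outer($"xs")).show()
    // +---+----+
    // |  a| col|
    // +---+----+
    // |  1|   1|
    // |  1|   2|
    // |  2|null|
    // +---+----+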
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index f0995ea1d0..b9871afd59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -87,6 +87,13 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(1) :: Row(2) :: Row(3) :: Nil)
}
+ test("single explode_outer") {
+ val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode_outer('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+ }
+
test("single posexplode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
@@ -94,6 +101,13 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
}
+ test("single posexplode_outer") {
+ val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+ checkAnswer(
+ df.select(posexplode_outer('intList)),
+ Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil)
+ }
+
test("explode and other columns") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
@@ -110,6 +124,26 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(1, Seq(1, 2, 3), 3) :: Nil)
}
+ test("explode_outer and other columns") {
+ val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode_outer('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) ::
+ Row(2, null) ::
+ Nil)
+
+ checkAnswer(
+ df.select($"*", explode_outer('intList)),
+ Row(1, Seq(1, 2, 3), 1) ::
+ Row(1, Seq(1, 2, 3), 2) ::
+ Row(1, Seq(1, 2, 3), 3) ::
+ Row(2, Seq(), null) ::
+ Nil)
+ }
+
test("aliased explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
@@ -122,6 +156,18 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(6) :: Nil)
}
+ test("aliased explode_outer") {
+ val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode_outer('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
test("explode on map") {
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
@@ -130,6 +176,15 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row("a", "b"))
}
+ test("explode_outer on map") {
+ val df = Seq((1, Map("a" -> "b")), (2, Map[String, String]()),
+ (3, Map("c" -> "d"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode_outer('map)),
+ Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil)
+ }
+
test("explode on map with aliases") {
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
@@ -138,6 +193,14 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row("a", "b"))
}
+ test("explode_outer on map with aliases") {
+ val df = Seq((3, None), (1, Some(Map("a" -> "b")))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+ Row("a", "b") :: Row(null, null) :: Nil)
+ }
+
test("self join explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
val exploded = df.select(explode('intList).as('i))
@@ -207,6 +270,19 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(1) :: Row(2) :: Nil)
}
+ test("inline_outer") {
+ val df = Seq((1, "2"), (3, "4"), (5, "6")).toDF("col1", "col2")
+ val df2 = df.select(when('col1 === 1, null).otherwise(array(struct('col1, 'col2))).as("col1"))
+ checkAnswer(
+ df2.selectExpr("inline(col1)"),
+ Row(3, "4") :: Row(5, "6") :: Nil
+ )
+ checkAnswer(
+ df2.selectExpr("inline_outer(col1)"),
+ Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil
+ )
+ }
+
test("SPARK-14986: Outer lateral view with empty generate expression") {
checkAnswer(
sql("select nil from values 1 lateral view outer explode(array()) n as nil"),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index 27ea167b90..df9390aec7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -93,6 +93,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT array(1,2,3)")
checkSqlGeneration("SELECT coalesce(null, 1, 2)")
checkSqlGeneration("SELECT explode(array(1,2,3))")
+ checkSqlGeneration("SELECT explode_outer(array())")
checkSqlGeneration("SELECT greatest(1,null,3)")
checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
@@ -102,6 +103,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
+ checkSqlGeneration("SELECT posexplode_outer(array())")
+ checkSqlGeneration("SELECT inline_outer(array(struct('a', 1)))")
checkSqlGeneration("SELECT rand(1)")
checkSqlGeneration("SELECT randn(3)")
checkSqlGeneration("SELECT struct(1,2,3)")