aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@databricks.com>2016-11-19 23:55:09 -0800
committerReynold Xin <rxin@databricks.com>2016-11-19 23:55:09 -0800
commit7ca7a635242377634c302b7816ce60bd9c908527 (patch)
tree7145deb5c19711a933ad0d9be2a2bb4de5f263c9 /sql
parenta64f25d8b403b17ff68c9575f6f35b22e5b62427 (diff)
downloadspark-7ca7a635242377634c302b7816ce60bd9c908527.tar.gz
spark-7ca7a635242377634c302b7816ce60bd9c908527.tar.bz2
spark-7ca7a635242377634c302b7816ce60bd9c908527.zip
[SPARK-15214][SQL] Code-generation for Generate
## What changes were proposed in this pull request? This PR adds code generation to `Generate`. It supports two code paths: - General `TraversableOnce` based iteration. This used for regular `Generator` (code generation supporting) expressions. This code path expects the expression to return a `TraversableOnce[InternalRow]` and it will iterate over the returned collection. This PR adds code generation for the `stack` generator. - Specialized `ArrayData/MapData` based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns. ### Benchmarks I have added some benchmarks and it seems we can create a nice speedup for explode: #### Environment ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 Intel(R) Core(TM) i7-4980HQ CPU 2.80GHz ``` #### Explode Array ##### Before ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7377 / 7607 2.3 439.7 1.0X generate explode array wholestage on 6055 / 6086 2.8 360.9 1.2X ``` ##### After ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7432 / 7696 2.3 443.0 1.0X generate explode array wholestage on 631 / 646 26.6 37.6 11.8X ``` #### Explode Map ##### Before ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 12792 / 12848 1.3 762.5 1.0X generate explode map wholestage on 11181 / 11237 1.5 666.5 1.1X ``` ##### After ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 10949 / 10972 1.5 652.6 1.0X generate explode map wholestage on 870 / 913 19.3 51.9 12.6X ``` #### Posexplode ##### Before ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7547 / 7580 2.2 449.8 1.0X generate posexplode array wholestage on 5786 / 5838 2.9 344.9 1.3X ``` ##### After ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7535 / 7548 2.2 449.1 1.0X generate posexplode array wholestage on 620 / 624 27.1 37.0 12.1X ``` #### Inline ##### Before ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6935 / 6978 2.4 413.3 1.0X generate inline array wholestage on 6360 / 6400 2.6 379.1 1.1X ``` ##### After ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6940 / 6966 2.4 413.6 1.0X generate inline array wholestage on 1002 / 1012 16.7 59.7 6.9X ``` #### Stack ##### Before ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12980 / 13104 1.3 773.7 1.0X generate stack wholestage on 11566 / 11580 1.5 689.4 1.1X ``` ##### After ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12875 / 12949 1.3 767.4 1.0X generate stack wholestage on 840 / 845 20.0 50.0 15.3X ``` ## How was this patch tested? Existing tests. Author: Herman van Hovell <hvanhovell@databricks.com> Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #13065 from hvanhovell/SPARK-15214.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala110
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala202
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala34
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala32
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala99
7 files changed, 463 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d042bfb63d..6c38f4998e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@@ -60,6 +62,26 @@ trait Generator extends Expression {
* rows can be made here.
*/
def terminate(): TraversableOnce[InternalRow] = Nil
+
+ /**
+ * Check if this generator supports code generation.
+ */
+ def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
+}
+
+/**
+ * A collection producing [[Generator]]. This trait provides a different path for code generation,
+ * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
+ */
+trait CollectionGenerator extends Generator {
+ /** The position of an element within the collection should also be returned. */
+ def position: Boolean
+
+ /** Rows will be inlined during generation. */
+ def inline: Boolean
+
+ /** The type of the returned collection object. */
+ def collectionType: DataType = dataType
}
/**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
- val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+ val inputSchema = StructType(children.map { e =>
+ StructField(e.simpleString, e.dataType, nullable = true)
+ })
CatalystTypeConverters.createToScalaConverter(inputSchema)
}.asInstanceOf[InternalRow => Row]
}
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
1 2
3 NULL
""")
-case class Stack(children: Seq[Expression])
- extends Expression with Generator with CodegenFallback {
+case class Stack(children: Seq[Expression]) extends Generator {
private lazy val numRows = children.head.eval().asInstanceOf[Int]
private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression])
InternalRow(fields: _*)
}
}
+
+
+ /**
+ * Only support code generation when stack produces 50 rows or less.
+ */
+ override def supportCodegen: Boolean = numRows <= 50
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ // Rows - we write these into an array.
+ val rowData = ctx.freshName("rows")
+ ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
+ val values = children.tail
+ val dataTypes = values.take(numFields).map(_.dataType)
+ val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+ val fields = Seq.tabulate(numFields) { col =>
+ val index = row * numFields + col
+ if (index < values.length) values(index) else Literal(null, dataTypes(col))
+ }
+ val eval = CreateStruct(fields).genCode(ctx)
+ s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
+ })
+
+ // Create the collection.
+ val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+ ctx.addMutableState(
+ s"$wrapperClass<InternalRow>",
+ ev.value,
+ s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
+ ev.copy(code = code, isNull = "false")
+ }
}
/**
- * A base class for Explode and PosExplode
+ * A base class for [[Explode]] and [[PosExplode]].
*/
-abstract class ExplodeBase(child: Expression, position: Boolean)
- extends UnaryExpression with Generator with CodegenFallback with Serializable {
+abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
+ override val inline: Boolean = false
- override def checkInputDataTypes(): TypeCheckResult = {
- if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case _: ArrayType | _: MapType =>
TypeCheckResult.TypeCheckSuccess
- } else {
+ case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function explode should be array or map type, not ${child.dataType}")
- }
}
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
@@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case ArrayType(et, containsNull) =>
if (position) {
new StructType()
- .add("pos", IntegerType, false)
+ .add("pos", IntegerType, nullable = false)
.add("col", et, containsNull)
} else {
new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case MapType(kt, vt, valueContainsNull) =>
if (position) {
new StructType()
- .add("pos", IntegerType, false)
- .add("key", kt, false)
+ .add("pos", IntegerType, nullable = false)
+ .add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
} else {
new StructType()
- .add("key", kt, false)
+ .add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
}
}
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
}
}
}
+
+ override def collectionType: DataType = child.dataType
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ child.genCode(ctx)
+ }
}
/**
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
20
""")
// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+case class Explode(child: Expression) extends ExplodeBase {
+ override val position: Boolean = false
+}
/**
* Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
1 20
""")
// scalastyle:on line.size.limit
-case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+case class PosExplode(child: Expression) extends ExplodeBase {
+ override val position = true
+}
/**
* Explodes an array of structs into a table.
@@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
1 a
2 b
""")
-case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
+ override val inline: Boolean = true
+ override val position: Boolean = false
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
- case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+ case ArrayType(st: StructType, _) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
@@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
}
override def elementSchema: StructType = child.dataType match {
- case ArrayType(et : StructType, _) => et
+ case ArrayType(st: StructType, _) => st
}
+ override def collectionType: DataType = child.dataType
+
private lazy val numFields = elementSchema.fields.length
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
yield inputArray.getStruct(i, numFields)
}
}
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ child.genCode(ctx)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1e39b24fe8..2db2a043e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{DataType, IntegerType}
class SubexpressionEliminationSuite extends SparkFunSuite {
test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
test("Children of CodegenFallback") {
val one = Literal(1)
val two = Add(one, one)
- val explode = Explode(two)
- val add = Add(two, explode)
+ val fallback = CodegenFallbackExpression(two)
+ val add = Add(two, fallback)
- var equivalence = new EquivalentExpressions
+ val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
- // the `two` inside `explode` should not be added
+ // the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}
}
+
+case class CodegenFallbackExpression(child: Expression)
+ extends UnaryExpression with CodegenFallback {
+ override def dataType: DataType = child.dataType
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 19fbf0c162..f80214af43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
/**
* For lazy computing, be sure the generator.terminate() called in the very last
@@ -40,6 +42,10 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
+ *
+ * This operator supports whole stage code generation for generators that do not implement
+ * terminate().
+ *
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
@@ -54,7 +60,7 @@ case class GenerateExec(
outer: Boolean,
output: Seq[Attribute],
child: SparkPlan)
- extends UnaryExecNode {
+ extends UnaryExecNode with CodegenSupport {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -103,5 +109,197 @@ case class GenerateExec(
}
}
}
-}
+ override def supportCodegen: Boolean = generator.supportCodegen
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ ctx.currentVars = input
+ ctx.copyResult = true
+
+ // Add input rows to the values when we are joining
+ val values = if (join) {
+ input
+ } else {
+ Seq.empty
+ }
+
+ boundGenerator match {
+ case e: CollectionGenerator => codeGenCollection(ctx, e, values, row)
+ case g => codeGenTraversableOnce(ctx, g, values, row)
+ }
+ }
+
+ /**
+ * Generate code for [[CollectionGenerator]] expressions.
+ */
+ private def codeGenCollection(
+ ctx: CodegenContext,
+ e: CollectionGenerator,
+ input: Seq[ExprCode],
+ row: ExprCode): String = {
+
+ // Generate code for the generator.
+ val data = e.genCode(ctx)
+
+ // Generate looping variables.
+ val index = ctx.freshName("index")
+
+ // Add a check if the generate outer flag is true.
+ val checks = optionalCode(outer, data.isNull)
+
+ // Add position
+ val position = if (e.position) {
+ Seq(ExprCode("", "false", index))
+ } else {
+ Seq.empty
+ }
+
+ // Generate code for either ArrayData or MapData
+ val (initMapData, updateRowData, values) = e.collectionType match {
+ case ArrayType(st: StructType, nullable) if e.inline =>
+ val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks)
+ val fieldChecks = checks ++ optionalCode(nullable, row.isNull)
+ val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+ codeGenAccessor(ctx, row.value, f.name, i.toString, f.dataType, f.nullable, fieldChecks)
+ }
+ ("", row.code, columns)
+
+ case ArrayType(dataType, nullable) =>
+ ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks)))
+
+ case MapType(keyType, valueType, valueContainsNull) =>
+ // Materialize the key and the value arrays before we enter the loop.
+ val keyArray = ctx.freshName("keyArray")
+ val valueArray = ctx.freshName("valueArray")
+ val initArrayData =
+ s"""
+ |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray();
+ |ArrayData $valueArray = ${data.isNull} ? null : ${data.value}.valueArray();
+ """.stripMargin
+ val values = Seq(
+ codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks),
+ codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks))
+ (initArrayData, "", values)
+ }
+
+ // In case of outer=true we need to make sure the loop is executed at-least once when the
+ // array/map contains no input. We do this by setting the looping index to -1 if there is no
+ // input, evaluation of the array is prevented by a check in the accessor code.
+ val numElements = ctx.freshName("numElements")
+ val init = if (outer) {
+ s"$numElements == 0 ? -1 : 0"
+ } else {
+ "0"
+ }
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ s"""
+ |${data.code}
+ |$initMapData
+ |int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements();
+ |for (int $index = $init; $index < $numElements; $index++) {
+ | $numOutput.add(1);
+ | $updateRowData
+ | ${consume(ctx, input ++ position ++ values)}
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generate code for a regular [[TraversableOnce]] returning [[Generator]].
+ */
+ private def codeGenTraversableOnce(
+ ctx: CodegenContext,
+ e: Expression,
+ input: Seq[ExprCode],
+ row: ExprCode): String = {
+
+ // Generate the code for the generator
+ val data = e.genCode(ctx)
+
+ // Generate looping variables.
+ val iterator = ctx.freshName("iterator")
+ val hasNext = ctx.freshName("hasNext")
+ val current = ctx.freshName("row")
+
+ // Add a check if the generate outer flag is true.
+ val checks = optionalCode(outer, s"!$hasNext")
+ val values = e.dataType match {
+ case ArrayType(st: StructType, nullable) =>
+ st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+ codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks)
+ }
+ }
+
+ // In case of outer=true we need to make sure the loop is executed at-least-once when the
+ // iterator contains no input. We do this by adding an 'outer' variable which guarantees
+ // execution of the first iteration even if there is no input. Evaluation of the iterator is
+ // prevented by checks in the next() and accessor code.
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ if (outer) {
+ val outerVal = ctx.freshName("outer")
+ s"""
+ |${data.code}
+ |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+ |boolean $outerVal = true;
+ |while ($iterator.hasNext() || $outerVal) {
+ | $numOutput.add(1);
+ | boolean $hasNext = $iterator.hasNext();
+ | InternalRow $current = (InternalRow)($hasNext? $iterator.next() : null);
+ | $outerVal = false;
+ | ${consume(ctx, input ++ values)}
+ |}
+ """.stripMargin
+ } else {
+ s"""
+ |${data.code}
+ |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+ |while ($iterator.hasNext()) {
+ | $numOutput.add(1);
+ | InternalRow $current = (InternalRow)($iterator.next());
+ | ${consume(ctx, input ++ values)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generate accessor code for ArrayData and InternalRows.
+ */
+ private def codeGenAccessor(
+ ctx: CodegenContext,
+ source: String,
+ name: String,
+ index: String,
+ dt: DataType,
+ nullable: Boolean,
+ initialChecks: Seq[String]): ExprCode = {
+ val value = ctx.freshName(name)
+ val javaType = ctx.javaType(dt)
+ val getter = ctx.getValue(source, dt, index)
+ val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
+ if (checks.nonEmpty) {
+ val isNull = ctx.freshName("isNull")
+ val code =
+ s"""
+ |boolean $isNull = ${checks.mkString(" || ")};
+ |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ } else {
+ ExprCode(s"$javaType $value = $getter;", "false", value)
+ }
+ }
+
+ private def optionalCode(condition: Boolean, code: => String): Seq[String] = {
+ if (condition) Seq(code)
+ else Seq.empty
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index aedc0a8d6f..f0995ea1d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructType}
class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -202,4 +206,34 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
Row(1) :: Row(2) :: Nil)
}
+
+ test("SPARK-14986: Outer lateral view with empty generate expression") {
+ checkAnswer(
+ sql("select nil from values 1 lateral view outer explode(array()) n as nil"),
+ Row(null) :: Nil
+ )
+ }
+
+ test("outer explode()") {
+ checkAnswer(
+ sql("select * from values 1, 2 lateral view outer explode(array()) a as b"),
+ Row(1, null) :: Row(2, null) :: Nil)
+ }
+
+ test("outer generator()") {
+ spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator())
+ checkAnswer(
+ sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
+ Row(1, null) :: Row(2, null) :: Nil)
+ }
+}
+
+case class EmptyGenerator() extends Generator {
+ override def children: Seq[Expression] = Nil
+ override def elementSchema: StructType = new StructType().add("id", IntegerType)
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val iteratorClass = classOf[Iterator[_]].getName
+ ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6b517bc70f..a715176d55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2086,13 +2086,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
- test("SPARK-14986: Outer lateral view with empty generate expression") {
- checkAnswer(
- sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),
- Row(null) :: Nil
- )
- }
-
test("data source table created in InMemoryCatalog should be able to read/write") {
withTable("tbl") {
sql("CREATE TABLE tbl(i INT, j STRING) USING parquet")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f26e5e7b69..e8ea7758cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Dataset, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
@@ -113,4 +115,32 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
}
+
+ test("generate should be included in WholeStageCodegen") {
+ import org.apache.spark.sql.functions._
+ val ds = spark.range(2).select(
+ col("id"),
+ explode(array(col("id") + 1, col("id") + 2)).as("value"))
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined)
+ assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3)))
+ }
+
+ test("large stack generator should not use WholeStageCodegen") {
+ def createStackGenerator(rows: Int): SparkPlan = {
+ val id = UnresolvedAttribute("id")
+ val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i))))
+ spark.range(500).select(Column(stack)).queryExecution.executedPlan
+ }
+ val isCodeGenerated: SparkPlan => Boolean = {
+ case WholeStageCodegenExec(_: GenerateExec) => true
+ case _ => false
+ }
+
+ // Only 'stack' generators that produce 50 rows or less are code generated.
+ assert(createStackGenerator(50).find(isCodeGenerated).isDefined)
+ assert(createStackGenerator(100).find(isCodeGenerated).isEmpty)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
index 470c78120b..01773c238b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
@@ -102,7 +102,7 @@ class MiscBenchmark extends BenchmarkBase {
}
benchmark.run()
- /**
+ /*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
@@ -124,7 +124,7 @@ class MiscBenchmark extends BenchmarkBase {
}
benchmark.run()
- /**
+ /*
model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
@@ -132,4 +132,99 @@ class MiscBenchmark extends BenchmarkBase {
collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X
*/
}
+
+ ignore("generate explode") {
+ val N = 1 << 24
+ runBenchmark("generate explode array", N) {
+ val df = sparkSession.range(N).selectExpr(
+ "id as key",
+ "array(rand(), rand(), rand(), rand(), rand()) as values")
+ df.selectExpr("key", "explode(values) value").count()
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+ Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+ generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ generate explode array wholestage off 6920 / 7129 2.4 412.5 1.0X
+ generate explode array wholestage on 623 / 646 26.9 37.1 11.1X
+ */
+
+ runBenchmark("generate explode map", N) {
+ val df = sparkSession.range(N).selectExpr(
+ "id as key",
+ "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs")
+ df.selectExpr("key", "explode(pairs) as (k, v)").count()
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+ Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+ generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ generate explode map wholestage off 11978 / 11993 1.4 714.0 1.0X
+ generate explode map wholestage on 866 / 919 19.4 51.6 13.8X
+ */
+
+ runBenchmark("generate posexplode array", N) {
+ val df = sparkSession.range(N).selectExpr(
+ "id as key",
+ "array(rand(), rand(), rand(), rand(), rand()) as values")
+ df.selectExpr("key", "posexplode(values) as (idx, value)").count()
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+ Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+ generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ generate posexplode array wholestage off 7502 / 7513 2.2 447.1 1.0X
+ generate posexplode array wholestage on 617 / 623 27.2 36.8 12.2X
+ */
+
+ runBenchmark("generate inline array", N) {
+ val df = sparkSession.range(N).selectExpr(
+ "id as key",
+ "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values")
+ df.selectExpr("key", "inline(values) as (r1, r2)").count()
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+ Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+ generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ generate inline array wholestage off 6901 / 6928 2.4 411.3 1.0X
+ generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X
+ */
+ }
+
+ ignore("generate regular generator") {
+ val N = 1 << 24
+ runBenchmark("generate stack", N) {
+ val df = sparkSession.range(N).selectExpr(
+ "id as key",
+ "id % 2 as t1",
+ "id % 3 as t2",
+ "id % 5 as t3",
+ "id % 7 as t4",
+ "id % 13 as t5")
+ df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count()
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+ Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+ generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ generate stack wholestage off 12953 / 13070 1.3 772.1 1.0X
+ generate stack wholestage on 836 / 847 20.1 49.8 15.5X
+ */
+ }
}