aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala110
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala16
2 files changed, 101 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d042bfb63d..6c38f4998e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@@ -60,6 +62,26 @@ trait Generator extends Expression {
* rows can be made here.
*/
def terminate(): TraversableOnce[InternalRow] = Nil
+
+ /**
+ * Check if this generator supports code generation.
+ */
+ def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
+}
+
+/**
+ * A collection producing [[Generator]]. This trait provides a different path for code generation,
+ * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
+ */
+trait CollectionGenerator extends Generator {
+ /** The position of an element within the collection should also be returned. */
+ def position: Boolean
+
+ /** Rows will be inlined during generation. */
+ def inline: Boolean
+
+ /** The type of the returned collection object. */
+ def collectionType: DataType = dataType
}
/**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
- val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+ val inputSchema = StructType(children.map { e =>
+ StructField(e.simpleString, e.dataType, nullable = true)
+ })
CatalystTypeConverters.createToScalaConverter(inputSchema)
}.asInstanceOf[InternalRow => Row]
}
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
1 2
3 NULL
""")
-case class Stack(children: Seq[Expression])
- extends Expression with Generator with CodegenFallback {
+case class Stack(children: Seq[Expression]) extends Generator {
private lazy val numRows = children.head.eval().asInstanceOf[Int]
private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression])
InternalRow(fields: _*)
}
}
+
+
+ /**
+ * Only support code generation when stack produces 50 rows or less.
+ */
+ override def supportCodegen: Boolean = numRows <= 50
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ // Rows - we write these into an array.
+ val rowData = ctx.freshName("rows")
+ ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
+ val values = children.tail
+ val dataTypes = values.take(numFields).map(_.dataType)
+ val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+ val fields = Seq.tabulate(numFields) { col =>
+ val index = row * numFields + col
+ if (index < values.length) values(index) else Literal(null, dataTypes(col))
+ }
+ val eval = CreateStruct(fields).genCode(ctx)
+ s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
+ })
+
+ // Create the collection.
+ val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+ ctx.addMutableState(
+ s"$wrapperClass<InternalRow>",
+ ev.value,
+ s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
+ ev.copy(code = code, isNull = "false")
+ }
}
/**
- * A base class for Explode and PosExplode
+ * A base class for [[Explode]] and [[PosExplode]].
*/
-abstract class ExplodeBase(child: Expression, position: Boolean)
- extends UnaryExpression with Generator with CodegenFallback with Serializable {
+abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
+ override val inline: Boolean = false
- override def checkInputDataTypes(): TypeCheckResult = {
- if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case _: ArrayType | _: MapType =>
TypeCheckResult.TypeCheckSuccess
- } else {
+ case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function explode should be array or map type, not ${child.dataType}")
- }
}
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
@@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case ArrayType(et, containsNull) =>
if (position) {
new StructType()
- .add("pos", IntegerType, false)
+ .add("pos", IntegerType, nullable = false)
.add("col", et, containsNull)
} else {
new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case MapType(kt, vt, valueContainsNull) =>
if (position) {
new StructType()
- .add("pos", IntegerType, false)
- .add("key", kt, false)
+ .add("pos", IntegerType, nullable = false)
+ .add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
} else {
new StructType()
- .add("key", kt, false)
+ .add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
}
}
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
}
}
}
+
+ override def collectionType: DataType = child.dataType
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ child.genCode(ctx)
+ }
}
/**
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
20
""")
// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+case class Explode(child: Expression) extends ExplodeBase {
+ override val position: Boolean = false
+}
/**
* Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
1 20
""")
// scalastyle:on line.size.limit
-case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+case class PosExplode(child: Expression) extends ExplodeBase {
+ override val position = true
+}
/**
* Explodes an array of structs into a table.
@@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
1 a
2 b
""")
-case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
+ override val inline: Boolean = true
+ override val position: Boolean = false
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
- case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+ case ArrayType(st: StructType, _) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
@@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
}
override def elementSchema: StructType = child.dataType match {
- case ArrayType(et : StructType, _) => et
+ case ArrayType(st: StructType, _) => st
}
+ override def collectionType: DataType = child.dataType
+
private lazy val numFields = elementSchema.fields.length
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
yield inputArray.getStruct(i, numFields)
}
}
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ child.genCode(ctx)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1e39b24fe8..2db2a043e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{DataType, IntegerType}
class SubexpressionEliminationSuite extends SparkFunSuite {
test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
test("Children of CodegenFallback") {
val one = Literal(1)
val two = Add(one, one)
- val explode = Explode(two)
- val add = Add(two, explode)
+ val fallback = CodegenFallbackExpression(two)
+ val add = Add(two, fallback)
- var equivalence = new EquivalentExpressions
+ val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
- // the `two` inside `explode` should not be added
+ // the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}
}
+
+case class CodegenFallbackExpression(child: Expression)
+ extends UnaryExpression with CodegenFallback {
+ override def dataType: DataType = child.dataType
+}