diff options
Diffstat (limited to 'sql/catalyst')
2 files changed, 101 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d042bfb63d..6c38f4998e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -60,6 +62,26 @@ trait Generator extends Expression { * rows can be made here. */ def terminate(): TraversableOnce[InternalRow] = Nil + + /** + * Check if this generator supports code generation. + */ + def supportCodegen: Boolean = !isInstanceOf[CodegenFallback] +} + +/** + * A collection producing [[Generator]]. This trait provides a different path for code generation, + * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object. + */ +trait CollectionGenerator extends Generator { + /** The position of an element within the collection should also be returned. */ + def position: Boolean + + /** Rows will be inlined during generation. */ + def inline: Boolean + + /** The type of the returned collection object. */ + def collectionType: DataType = dataType } /** @@ -77,7 +99,9 @@ case class UserDefinedGenerator( private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + val inputSchema = StructType(children.map { e => + StructField(e.simpleString, e.dataType, nullable = true) + }) CatalystTypeConverters.createToScalaConverter(inputSchema) }.asInstanceOf[InternalRow => Row] } @@ -109,8 +133,7 @@ case class UserDefinedGenerator( 1 2 3 NULL """) -case class Stack(children: Seq[Expression]) - extends Expression with Generator with CodegenFallback { +case class Stack(children: Seq[Expression]) extends Generator { private lazy val numRows = children.head.eval().asInstanceOf[Int] private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt @@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression]) InternalRow(fields: _*) } } + + + /** + * Only support code generation when stack produces 50 rows or less. + */ + override def supportCodegen: Boolean = numRows <= 50 + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Rows - we write these into an array. + val rowData = ctx.freshName("rows") + ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + val values = children.tail + val dataTypes = values.take(numFields).map(_.dataType) + val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => + val fields = Seq.tabulate(numFields) { col => + val index = row * numFields + col + if (index < values.length) values(index) else Literal(null, dataTypes(col)) + } + val eval = CreateStruct(fields).genCode(ctx) + s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + }) + + // Create the collection. + val wrapperClass = classOf[mutable.WrappedArray[_]].getName + ctx.addMutableState( + s"$wrapperClass<InternalRow>", + ev.value, + s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + ev.copy(code = code, isNull = "false") + } } /** - * A base class for Explode and PosExplode + * A base class for [[Explode]] and [[PosExplode]]. */ -abstract class ExplodeBase(child: Expression, position: Boolean) - extends UnaryExpression with Generator with CodegenFallback with Serializable { +abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable { + override val inline: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: ArrayType | _: MapType => TypeCheckResult.TypeCheckSuccess - } else { + case _ => TypeCheckResult.TypeCheckFailure( s"input to function explode should be array or map type, not ${child.dataType}") - } } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case ArrayType(et, containsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) + .add("pos", IntegerType, nullable = false) .add("col", et, containsNull) } else { new StructType() @@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case MapType(kt, vt, valueContainsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) - .add("key", kt, false) + .add("pos", IntegerType, nullable = false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } else { new StructType() - .add("key", kt, false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } } @@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) } } } + + override def collectionType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } /** @@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean) 20 """) // scalastyle:on line.size.limit -case class Explode(child: Expression) extends ExplodeBase(child, position = false) +case class Explode(child: Expression) extends ExplodeBase { + override val position: Boolean = false +} /** * Given an input array produces a sequence of rows for each position and value in the array. @@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals 1 20 """) // scalastyle:on line.size.limit -case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) +case class PosExplode(child: Expression) extends ExplodeBase { + override val position = true +} /** * Explodes an array of structs into a table. @@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t 1 a 2 b """) -case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { +case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator { + override val inline: Boolean = true + override val position: Boolean = false override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(et, _) if et.isInstanceOf[StructType] => + case ArrayType(st: StructType, _) => TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( @@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with } override def elementSchema: StructType = child.dataType match { - case ArrayType(et : StructType, _) => et + case ArrayType(st: StructType, _) => st } + override def collectionType: DataType = child.dataType + private lazy val numFields = elementSchema.fields.length override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with yield inputArray.getStruct(i, numFields) } } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 1e39b24fe8..2db2a043e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{DataType, IntegerType} class SubexpressionEliminationSuite extends SparkFunSuite { test("Semantic equals and hash") { @@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite { test("Children of CodegenFallback") { val one = Literal(1) val two = Add(one, one) - val explode = Explode(two) - val add = Add(two, explode) + val fallback = CodegenFallbackExpression(two) + val add = Add(two, fallback) - var equivalence = new EquivalentExpressions + val equivalence = new EquivalentExpressions equivalence.addExprTree(add, true) - // the `two` inside `explode` should not be added + // the `two` inside `fallback` should not be added assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode } } + +case class CodegenFallbackExpression(child: Expression) + extends UnaryExpression with CodegenFallback { + override def dataType: DataType = child.dataType +} |