aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 48
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala 43
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala 52
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 122
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala 2
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala 82
6 files changed, 207 insertions, 142 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 7b41c9a3f3..c21f4d626a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import com.google.common.cache.{CacheBuilder, CacheLoader}
@@ -265,6 +266,45 @@ class CodeGenContext {
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
+
+ /**
+ * Splits the generated code of expressions into multiple functions, because a JVM method is
+ * limited to 64KB of bytecode. Returns the code inline if it all fits in a single block.
+ * @param row the variable name of row that is used by expressions
+ * @param expressions generated code snippets, emitted in the given order
+ */
+ def splitExpressions(row: String, expressions: Seq[String]): String = {
+ val blocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (code <- expressions) {
+ // The bytecode size is unknown here, so the length of the source text is used as the limit
+ if (blockBuilder.length > 64 * 1000) {
+ blocks.append(blockBuilder.toString())
+ blockBuilder.clear()
+ }
+ blockBuilder.append(code)
+ }
+ blocks.append(blockBuilder.toString())
+
+ if (blocks.length == 1) {
+ // A single block needs no helper function, so it is emitted inline
+ blocks.head
+ } else {
+ val apply = freshName("apply")
+ val functions = blocks.zipWithIndex.map { case (body, i) =>
+ val name = s"${apply}_$i"
+ val code = s"""
+ |private void $name(InternalRow $row) {
+ | $body
+ |}
+ """.stripMargin
+ addNewFunction(name, code)
+ name
+ }
+
+ functions.map(name => s"$name($row);").mkString("\n")
+ }
+ }
}
/**
@@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def declareMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
- }.mkString
+ }.mkString("\n")
}
protected def initMutableStates(ctx: CodeGenContext): String = {
- ctx.mutableStates.map(_._3).mkString
+ ctx.mutableStates.map(_._3).mkString("\n")
}
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
- ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
+ ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}
/**
@@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
+ // Cannot be under package codegen, or it fails with java.lang.InstantiationException
+ evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
evaluator.setDefaultImports(Array(
classOf[PlatformDependent].getName,
classOf[InternalRow].getName,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ac58423cd8..b4d4df8934 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map {
+ val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
@@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}
}
- // collect projections into blocks as function has 64kb codesize limit in JVM
- val projectionBlocks = new ArrayBuffer[String]()
- val blockBuilder = new StringBuilder()
- for (projection <- projectionCode) {
- if (blockBuilder.length > 16 * 1000) {
- projectionBlocks.append(blockBuilder.toString())
- blockBuilder.clear()
- }
- blockBuilder.append(projection)
- }
- projectionBlocks.append(blockBuilder.toString())
-
- val (projectionFuns, projectionCalls) = {
- // inline execution if codesize limit was not broken
- if (projectionBlocks.length == 1) {
- ("", projectionBlocks.head)
- } else {
- (
- projectionBlocks.zipWithIndex.map { case (body, i) =>
- s"""
- |private void apply$i(InternalRow i) {
- | $body
- |}
- """.stripMargin
- }.mkString,
- projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
- )
- }
- }
+ val allProjections = ctx.splitExpressions("i", projectionCodes)
val code = s"""
public Object generate($exprType[] expr) {
- return new SpecificProjection(expr);
+ return new SpecificMutableProjection(expr);
}
- class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
+ class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} {
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
- public SpecificProjection($exprType[] expr) {
+ public SpecificMutableProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
@@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}
- $projectionFuns
-
public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
- $projectionCalls
-
+ $allProjections
return mutableRow;
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ef08ddf041..7ad352d7ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types._
@@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
+ // These expressions could be split into multiple functions
+ ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
@@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$values[$i] = ${converter.primitive};
}
"""
- }.mkString("\n")
-
+ }
+ val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
- final Object[] $values = new Object[${schema.length}];
- $fieldWriters
+ this.$values = new Object[${schema.length}];
+ $allFields
final InternalRow $output = new $rowClass($values);
"""
@@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
protected def create(expressions: Seq[Expression]): Projection = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map {
+ val expressionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
@@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
- // collect projections into blocks as function has 64kb codesize limit in JVM
- val projectionBlocks = new ArrayBuffer[String]()
- val blockBuilder = new StringBuilder()
- for (projection <- projectionCode) {
- if (blockBuilder.length > 16 * 1000) {
- projectionBlocks.append(blockBuilder.toString())
- blockBuilder.clear()
- }
- blockBuilder.append(projection)
- }
- projectionBlocks.append(blockBuilder.toString())
-
- val (projectionFuns, projectionCalls) = {
- // inline it if we have only one block
- if (projectionBlocks.length == 1) {
- ("", projectionBlocks.head)
- } else {
- (
- projectionBlocks.zipWithIndex.map { case (body, i) =>
- s"""
- |private void apply$i(InternalRow i) {
- | $body
- |}
- """.stripMargin
- }.mkString,
- projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
- )
- }
- }
-
+ val allExpressions = ctx.splitExpressions("i", expressionCodes)
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
@@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificSafeProjection($exprType[] expr) {
expressions = expr;
@@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
${initMutableStates(ctx)}
}
- $projectionFuns
-
public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
- $projectionCalls
-
+ $allExpressions
return mutableRow;
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index d8912df694..29f6a7b981 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
/**
* Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
- private val PlatformDependent = classOf[PlatformDependent].getName
-
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case NullType => true
@@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s" + $DecimalWriter.getSize(${ev.primitive})"
+ s"$DecimalWriter.getSize(${ev.primitive})"
case StringType =>
- s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
case BinaryType =>
- s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
case CalendarIntervalType =>
- s" + (${ev.isNull} ? 0 : 16)"
+ s"${ev.isNull} ? 0 : 16"
case _: StructType =>
- s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
case _: ArrayType =>
- s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
case _: MapType =>
- s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
case _ => ""
}
@@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
*/
private def createCodeForStruct(
ctx: CodeGenContext,
+ row: String,
inputs: Seq[GeneratedExpressionCode],
inputTypes: Seq[DataType]): GeneratedExpressionCode = {
+ val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
+
val output = ctx.freshName("convertedStruct")
- ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
+ ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();")
val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
- val numBytes = ctx.freshName("numBytes")
+ ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
val cursor = ctx.freshName("cursor")
+ ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
+ val tmp = ctx.freshName("tmpBuffer")
- val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
- createConvertCode(ctx, input, dt)
- }
-
- val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
- val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
- genAdditionalSize(dt, ev)
- }.mkString("")
-
- val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
- val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
- if (dt.isInstanceOf[DecimalType]) {
- // Can't call setNullAt() for DecimalType
+ val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
+ val ev = createConvertCode(ctx, input, dt)
+ val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
+ val numBytes = ctx.freshName("numBytes")
s"""
+ int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
+ if ($buffer.length < $numBytes) {
+ // This will not happen frequently, because the buffer is re-used.
+ byte[] $tmp = new byte[$numBytes * 2];
+ PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length);
+ $buffer = $tmp;
+ }
+ $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ ${inputTypes.length}, $numBytes);
+ """
+ } else {
+ ""
+ }
+ val update = dt match {
+ case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
+ // Can't call setNullAt() for DecimalType
+ s"""
if (${ev.isNull}) {
- $cursor += $DecimalWriter.write($output, $i, $cursor, null);
+ $cursor += $DecimalWriter.write($output, $i, $cursor, null);
} else {
- $update;
+ ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
- } else {
- s"""
+ case _ =>
+ s"""
if (${ev.isNull}) {
$output.setNullAt($i);
} else {
- $update;
+ ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
}
- }.mkString("\n")
+ s"""
+ ${ev.code}
+ $growBuffer
+ $update
+ """
+ }
val code = s"""
- ${convertedFields.map(_.code).mkString("\n")}
-
- final int $numBytes = $fixedSize $additionalSize;
- if ($numBytes > $buffer.length) {
- $buffer = new byte[$numBytes];
- }
-
- $output.pointTo(
- $buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET,
- ${inputTypes.length},
- $numBytes);
-
- int $cursor = $fixedSize;
-
- $fieldWriters
+ $cursor = $fixedSize;
+ $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor);
+ ${ctx.splitExpressions(row, convertedFields)}
"""
GeneratedExpressionCode(code, "false", output)
}
@@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
- $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
+ PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
${convertedElement.primitive});
$cursor += $elementSize;
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
- $PlatformDependent.UNSAFE.putLong(
+ PlatformDependent.UNSAFE.putLong(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
${convertedElement.primitive}.toUnscaledLong());
$cursor += 8;
"""
@@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
$cursor += $writer.write(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
$elements[$index]);
"""
}
@@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
for (int $index = 0; $index < $numElements; $index++) {
if ($checkNull) {
// If element is null, write the negative value address into offset region.
- $PlatformDependent.UNSAFE.putInt(
+ PlatformDependent.UNSAFE.putInt(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
-$cursor);
} else {
- $PlatformDependent.UNSAFE.putInt(
+ PlatformDependent.UNSAFE.putInt(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
$cursor);
$writeElement
@@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$output.pointTo(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
$numElements,
$numBytes);
}
@@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldIsNull = s"$tmp.isNullAt($i)"
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
}
- val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
+ val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
@@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
val exprTypes = expressions.map(_.dataType)
- createCodeForStruct(ctx, exprEvals, exprTypes)
+ createCodeForStruct(ctx, "i", exprEvals, exprTypes)
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 30b51dd83f..8aaa5b4300 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
""".stripMargin
}
- }.mkString
+ }.mkString("\n")
// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
new file mode 100644
index 0000000000..8c7ee8720f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Tests generated unsafe, safe and mutable projections on very wide, nested rows.
+ */
+class GeneratedProjectionSuite extends SparkFunSuite {
+
+ test("generated projections on wider table") {
+ val N = 1000
+ val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
+ val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+ val wideRow2 = new GenericInternalRow(
+ (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+ val joined = new JoinedRow(wideRow1, wideRow2)
+ val joinedSchema = StructType(schema1 ++ schema2)
+ val nested = new JoinedRow(InternalRow(joined, joined), joined)
+ val nestedSchema = StructType(
+ Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
+
+ // UnsafeProjection: top-level fields sit after the two leading struct columns (offset 2)
+ val unsafeProj = UnsafeProjection.create(nestedSchema)
+ val unsafe: UnsafeRow = unsafeProj(nested)
+ (0 until N).foreach { i =>
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(i + 1 === unsafe.getInt(i + 2))
+ assert(s === unsafe.getUTF8String(i + 2 + N))
+ assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // SafeProjection: convert the UnsafeRow back and check the same fields again
+ val safeProj = FromUnsafeProjection(nestedSchema)
+ val result = safeProj(unsafe)
+ // Compare field by field, because GenericInternalRow can't be compared with JoinedRow directly
+ (0 until N).foreach { i =>
+ val r = i + 1
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(r === result.getInt(i + 2))
+ assert(s === result.getUTF8String(i + 2 + N))
+ assert(r === result.getStruct(0, N * 2).getInt(i))
+ assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(r === result.getStruct(1, N * 2).getInt(i))
+ assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // MutableProjection: apply it twice to check that repeated calls give the same result
+ val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val mutableProj = GenerateMutableProjection.generate(exprs)()
+ val row1 = mutableProj(result)
+ assert(result === row1)
+ val row2 = mutableProj(result)
+ assert(result === row2)
+ }
+}