author     Davies Liu <davies@databricks.com>    2015-08-10 13:52:18 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-08-10 13:52:18 -0700
commit     fe2fb7fb7189d183a4273ad27514af4b6b461f26 (patch)
tree       c9983b34e92572f33bca638e5fe2dfc5af3e6c86
parent     40ed2af587cedadc6e5249031857a922b3b234ca (diff)
[SPARK-9620] [SQL] generated UnsafeProjection should support many columns or large expressions
Currently the generated UnsafeProjection can exceed the JVM's 64KB bytecode limit for a single method. This patch splits the generated expressions into multiple functions to avoid that limitation. After this patch, we can handle tables with up to 64k columns (at which point we hit the JVM's limit on the number of constants in a class), which should be enough in practice.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #8044 from davies/wider_table and squashes the following commits:

9192e6c [Davies Liu] fix generated safe projection
d1ef81a [Davies Liu] fix failed tests
737b3d3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table
ffcd132 [Davies Liu] address comments
1b95be4 [Davies Liu] put the generated class into sql package
77ed72d [Davies Liu] address comments
4518e17 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table
75ccd01 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table
495e932 [Davies Liu] support wider table with more than 1k columns for generated projections
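The splitting strategy described above is simple enough to sketch outside of Spark's codegen framework: concatenate the per-expression snippets into blocks that stay under a source-size budget (the final bytecode size cannot be known while the source is still being assembled, so generated source length serves as a rough proxy for the JVM's 64KB-per-method limit), wrap each block in its own private helper, and emit a sequence of calls where the code used to be inlined. The standalone Scala sketch below illustrates only that idea; CodeSplitSketch, splitIntoFunctions, maxBlockSize, and the apply_N helper names are illustrative, whereas the actual implementation is CodeGenContext.splitExpressions in the diff that follows, which registers helpers through addNewFunction and returns just the call sites.

import scala.collection.mutable.ArrayBuffer

object CodeSplitSketch {

  // Budget per helper, measured on the generated *source*. The real JVM limit
  // is 64KB of bytecode per method, which is unknown at generation time, so
  // source length is used as a rough proxy.
  val maxBlockSize: Int = 64 * 1000

  // Groups per-expression snippets into blocks; when more than one block is
  // needed, wraps each block in a helper and returns (helper defs, call sites).
  def splitIntoFunctions(row: String, snippets: Seq[String]): (String, String) = {
    val blocks = ArrayBuffer.empty[String]
    val current = new StringBuilder
    for (snippet <- snippets) {
      if (current.length > maxBlockSize) {
        blocks += current.toString()
        current.clear()
      }
      current.append(snippet)
    }
    blocks += current.toString()

    if (blocks.length == 1) {
      // Small enough to stay inlined in apply().
      ("", blocks.head)
    } else {
      val helpers = blocks.zipWithIndex.map { case (body, i) =>
        s"private void apply_$i(InternalRow $row) {\n$body}\n"
      }
      val calls = blocks.indices.map(i => s"apply_$i($row);")
      (helpers.mkString("\n"), calls.mkString("\n"))
    }
  }

  def main(args: Array[String]): Unit = {
    // A few thousand wide-row field assignments are far too big for one method.
    val snippets = (0 until 2000).map(i => s"  values[$i] = i.getInt($i);\n" * 20)
    val (helpers, calls) = splitIntoFunctions("i", snippets)
    println(s"generated ${calls.split("\n").length} helpers")
    println(s"helper source size: ${helpers.length} chars")
  }
}

For the straight-line assignments that projection codegen emits, 64,000 characters of source compiles to well under 64KB of bytecode, so keying the split on source length keeps each helper safely below the limit while remaining cheap to check.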
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala             |  48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala  |  43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala     |  52
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala   | 122
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala    |   2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala   |  82
6 files changed, 207 insertions(+), 142 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 7b41c9a3f3..c21f4d626a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import com.google.common.cache.{CacheBuilder, CacheLoader}
@@ -265,6 +266,45 @@ class CodeGenContext {
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
+
+ /**
+ * Splits the generated code of expressions into multiple functions, because function has
+ * 64kb code size limit in JVM
+ *
+ * @param row the variable name of row that is used by expressions
+ */
+ def splitExpressions(row: String, expressions: Seq[String]): String = {
+ val blocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (code <- expressions) {
+ // We can't know how many byte code will be generated, so use the number of bytes as limit
+ if (blockBuilder.length > 64 * 1000) {
+ blocks.append(blockBuilder.toString())
+ blockBuilder.clear()
+ }
+ blockBuilder.append(code)
+ }
+ blocks.append(blockBuilder.toString())
+
+ if (blocks.length == 1) {
+ // inline execution if only one block
+ blocks.head
+ } else {
+ val apply = freshName("apply")
+ val functions = blocks.zipWithIndex.map { case (body, i) =>
+ val name = s"${apply}_$i"
+ val code = s"""
+ |private void $name(InternalRow $row) {
+ | $body
+ |}
+ """.stripMargin
+ addNewFunction(name, code)
+ name
+ }
+
+ functions.map(name => s"$name($row);").mkString("\n")
+ }
+ }
}
/**
@@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def declareMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
- }.mkString
+ }.mkString("\n")
}
protected def initMutableStates(ctx: CodeGenContext): String = {
- ctx.mutableStates.map(_._3).mkString
+ ctx.mutableStates.map(_._3).mkString("\n")
}
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
- ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
+ ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}
/**
@@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
+ // Cannot be under package codegen, or fail with java.lang.InstantiationException
+ evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
evaluator.setDefaultImports(Array(
classOf[PlatformDependent].getName,
classOf[InternalRow].getName,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ac58423cd8..b4d4df8934 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map {
+ val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
@@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}
}
- // collect projections into blocks as function has 64kb codesize limit in JVM
- val projectionBlocks = new ArrayBuffer[String]()
- val blockBuilder = new StringBuilder()
- for (projection <- projectionCode) {
- if (blockBuilder.length > 16 * 1000) {
- projectionBlocks.append(blockBuilder.toString())
- blockBuilder.clear()
- }
- blockBuilder.append(projection)
- }
- projectionBlocks.append(blockBuilder.toString())
-
- val (projectionFuns, projectionCalls) = {
- // inline execution if codesize limit was not broken
- if (projectionBlocks.length == 1) {
- ("", projectionBlocks.head)
- } else {
- (
- projectionBlocks.zipWithIndex.map { case (body, i) =>
- s"""
- |private void apply$i(InternalRow i) {
- | $body
- |}
- """.stripMargin
- }.mkString,
- projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
- )
- }
- }
+ val allProjections = ctx.splitExpressions("i", projectionCodes)
val code = s"""
public Object generate($exprType[] expr) {
- return new SpecificProjection(expr);
+ return new SpecificMutableProjection(expr);
}
- class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
+ class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} {
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
- public SpecificProjection($exprType[] expr) {
+ public SpecificMutableProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
@@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}
- $projectionFuns
-
public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
- $projectionCalls
-
+ $allProjections
return mutableRow;
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ef08ddf041..7ad352d7ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types._
@@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
+ // These expressions could be splitted into multiple functions
+ ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
@@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$values[$i] = ${converter.primitive};
}
"""
- }.mkString("\n")
-
+ }
+ val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
- final Object[] $values = new Object[${schema.length}];
- $fieldWriters
+ this.$values = new Object[${schema.length}];
+ $allFields
final InternalRow $output = new $rowClass($values);
"""
@@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
protected def create(expressions: Seq[Expression]): Projection = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map {
+ val expressionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
@@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
- // collect projections into blocks as function has 64kb codesize limit in JVM
- val projectionBlocks = new ArrayBuffer[String]()
- val blockBuilder = new StringBuilder()
- for (projection <- projectionCode) {
- if (blockBuilder.length > 16 * 1000) {
- projectionBlocks.append(blockBuilder.toString())
- blockBuilder.clear()
- }
- blockBuilder.append(projection)
- }
- projectionBlocks.append(blockBuilder.toString())
-
- val (projectionFuns, projectionCalls) = {
- // inline it if we have only one block
- if (projectionBlocks.length == 1) {
- ("", projectionBlocks.head)
- } else {
- (
- projectionBlocks.zipWithIndex.map { case (body, i) =>
- s"""
- |private void apply$i(InternalRow i) {
- | $body
- |}
- """.stripMargin
- }.mkString,
- projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
- )
- }
- }
-
+ val allExpressions = ctx.splitExpressions("i", expressionCodes)
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
@@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
+ ${declareAddedFunctions(ctx)}
public SpecificSafeProjection($exprType[] expr) {
expressions = expr;
@@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
${initMutableStates(ctx)}
}
- $projectionFuns
-
public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
- $projectionCalls
-
+ $allExpressions
return mutableRow;
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index d8912df694..29f6a7b981 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
/**
* Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
- private val PlatformDependent = classOf[PlatformDependent].getName
-
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case NullType => true
@@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s" + $DecimalWriter.getSize(${ev.primitive})"
+ s"$DecimalWriter.getSize(${ev.primitive})"
case StringType =>
- s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
case BinaryType =>
- s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
case CalendarIntervalType =>
- s" + (${ev.isNull} ? 0 : 16)"
+ s"${ev.isNull} ? 0 : 16"
case _: StructType =>
- s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
case _: ArrayType =>
- s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
case _: MapType =>
- s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
+ s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
case _ => ""
}
@@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
*/
private def createCodeForStruct(
ctx: CodeGenContext,
+ row: String,
inputs: Seq[GeneratedExpressionCode],
inputTypes: Seq[DataType]): GeneratedExpressionCode = {
+ val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
+
val output = ctx.freshName("convertedStruct")
- ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
+ ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();")
val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
- val numBytes = ctx.freshName("numBytes")
+ ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
val cursor = ctx.freshName("cursor")
+ ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
+ val tmp = ctx.freshName("tmpBuffer")
- val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
- createConvertCode(ctx, input, dt)
- }
-
- val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
- val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
- genAdditionalSize(dt, ev)
- }.mkString("")
-
- val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
- val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
- if (dt.isInstanceOf[DecimalType]) {
- // Can't call setNullAt() for DecimalType
+ val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
+ val ev = createConvertCode(ctx, input, dt)
+ val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
+ val numBytes = ctx.freshName("numBytes")
s"""
+ int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
+ if ($buffer.length < $numBytes) {
+ // This will not happen frequently, because the buffer is re-used.
+ byte[] $tmp = new byte[$numBytes * 2];
+ PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length);
+ $buffer = $tmp;
+ }
+ $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ ${inputTypes.length}, $numBytes);
+ """
+ } else {
+ ""
+ }
+ val update = dt match {
+ case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
+ // Can't call setNullAt() for DecimalType
+ s"""
if (${ev.isNull}) {
- $cursor += $DecimalWriter.write($output, $i, $cursor, null);
+ $cursor += $DecimalWriter.write($output, $i, $cursor, null);
} else {
- $update;
+ ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
- } else {
- s"""
+ case _ =>
+ s"""
if (${ev.isNull}) {
$output.setNullAt($i);
} else {
- $update;
+ ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
}
- }.mkString("\n")
+ s"""
+ ${ev.code}
+ $growBuffer
+ $update
+ """
+ }
val code = s"""
- ${convertedFields.map(_.code).mkString("\n")}
-
- final int $numBytes = $fixedSize $additionalSize;
- if ($numBytes > $buffer.length) {
- $buffer = new byte[$numBytes];
- }
-
- $output.pointTo(
- $buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET,
- ${inputTypes.length},
- $numBytes);
-
- int $cursor = $fixedSize;
-
- $fieldWriters
+ $cursor = $fixedSize;
+ $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor);
+ ${ctx.splitExpressions(row, convertedFields)}
"""
GeneratedExpressionCode(code, "false", output)
}
@@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
- $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
+ PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
${convertedElement.primitive});
$cursor += $elementSize;
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
- $PlatformDependent.UNSAFE.putLong(
+ PlatformDependent.UNSAFE.putLong(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
${convertedElement.primitive}.toUnscaledLong());
$cursor += 8;
"""
@@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
$cursor += $writer.write(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
$elements[$index]);
"""
}
@@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
for (int $index = 0; $index < $numElements; $index++) {
if ($checkNull) {
// If element is null, write the negative value address into offset region.
- $PlatformDependent.UNSAFE.putInt(
+ PlatformDependent.UNSAFE.putInt(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
-$cursor);
} else {
- $PlatformDependent.UNSAFE.putInt(
+ PlatformDependent.UNSAFE.putInt(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
$cursor);
$writeElement
@@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$output.pointTo(
$buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
$numElements,
$numBytes);
}
@@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldIsNull = s"$tmp.isNullAt($i)"
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
}
- val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
+ val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
@@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
val exprTypes = expressions.map(_.dataType)
- createCodeForStruct(ctx, exprEvals, exprTypes)
+ createCodeForStruct(ctx, "i", exprEvals, exprTypes)
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 30b51dd83f..8aaa5b4300 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
""".stripMargin
}
- }.mkString
+ }.mkString("\n")
// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
new file mode 100644
index 0000000000..8c7ee8720f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A test suite for generated projections
+ */
+class GeneratedProjectionSuite extends SparkFunSuite {
+
+ test("generated projections on wider table") {
+ val N = 1000
+ val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
+ val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+ val wideRow2 = new GenericInternalRow(
+ (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+ val joined = new JoinedRow(wideRow1, wideRow2)
+ val joinedSchema = StructType(schema1 ++ schema2)
+ val nested = new JoinedRow(InternalRow(joined, joined), joined)
+ val nestedSchema = StructType(
+ Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
+
+ // test generated UnsafeProjection
+ val unsafeProj = UnsafeProjection.create(nestedSchema)
+ val unsafe: UnsafeRow = unsafeProj(nested)
+ (0 until N).foreach { i =>
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(i + 1 === unsafe.getInt(i + 2))
+ assert(s === unsafe.getUTF8String(i + 2 + N))
+ assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated SafeProjection
+ val safeProj = FromUnsafeProjection(nestedSchema)
+ val result = safeProj(unsafe)
+ // Can't compare GenericInternalRow with JoinedRow directly
+ (0 until N).foreach { i =>
+ val r = i + 1
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(r === result.getInt(i + 2))
+ assert(s === result.getUTF8String(i + 2 + N))
+ assert(r === result.getStruct(0, N * 2).getInt(i))
+ assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(r === result.getStruct(1, N * 2).getInt(i))
+ assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated MutableProjection
+ val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val mutableProj = GenerateMutableProjection.generate(exprs)()
+ val row1 = mutableProj(result)
+ assert(result === row1)
+ val row2 = mutableProj(result)
+ assert(result === row2)
+ }
+}