aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-09-09 10:57:29 -0700
committerDavies Liu <davies.liu@gmail.com>2015-09-09 10:57:29 -0700
commit71da1633c4dcc4e748fbe3b3236af90032b695ae (patch)
tree249ad56ab7309df9e3c4441e90ca1cf1589896e7 /sql
parentc0052d8d09eebadadb5ed35ac512caaf73919551 (diff)
downloadspark-71da1633c4dcc4e748fbe3b3236af90032b695ae.tar.gz
spark-71da1633c4dcc4e748fbe3b3236af90032b695ae.tar.bz2
spark-71da1633c4dcc4e748fbe3b3236af90032b695ae.zip
[SPARK-10461] [SQL] make sure `input.primitive` is always variable name not code at `GenerateUnsafeProjection`
When we generate unsafe code inside `createCodeForXXX`, we always assign the `input.primitive` to a temp variable in case `input.primitive` is expression code. This PR did some refactor to make sure `input.primitive` is always variable name, and some other typo and style fixes. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8613 from cloud-fan/minor.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala85
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala33
5 files changed, 75 insertions, 67 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bf96248fea..da3103b4eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -84,11 +84,11 @@ class CodeGenContext {
/**
* Holding all the functions those will be added into generated class.
*/
- val addedFuntions: mutable.Map[String, String] =
+ val addedFunctions: mutable.Map[String, String] =
mutable.Map.empty[String, String]
def addNewFunction(funcName: String, funcCode: String): Unit = {
- addedFuntions += ((funcName, funcCode))
+ addedFunctions += ((funcName, funcCode))
}
final val JAVA_BOOLEAN = "boolean"
@@ -298,8 +298,8 @@ class CodeGenContext {
| $body
|}
""".stripMargin
- addNewFunction(name, code)
- name
+ addNewFunction(name, code)
+ name
}
functions.map(name => s"$name($row);").mkString("\n")
@@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
- ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+ ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index b4d4df8934..793023b9fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types.DecimalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c744e84d82..2164ddf03d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val columns = expressions.zipWithIndex.map {
case (e, i) =>
s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
- }.mkString("\n ")
+ }.mkString("\n")
val initColumns = expressions.zipWithIndex.map {
case (e, i) =>
@@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val getCases = (0 until expressions.size).map { i =>
s"case $i: return c$i;"
- }.mkString("\n ")
+ }.mkString("\n")
val updateCases = expressions.zipWithIndex.map { case (e, i) =>
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
- }.mkString("\n ")
+ }.mkString("\n")
val specificAccessorFunctions = ctx.primitiveTypes.map { jt =>
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == jt =>
Some(s"case $i: return c$i;")
case _ => None
- }.mkString("\n ")
+ }.mkString("\n")
if (cases.length > 0) {
val getter = "get" + ctx.primitiveTypeName(jt)
s"""
@@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case (e, i) if ctx.javaType(e.dataType) == jt =>
Some(s"case $i: { c$i = value; return; }")
case _ => None
- }.mkString("\n ")
+ }.mkString("\n")
if (cases.length > 0) {
val setter = "set" + ctx.primitiveTypeName(jt)
s"""
@@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
- }.mkString("\n ")
+ }.mkString("\n")
val code = s"""
public SpecificProjection generate($exprType[] expr) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index b570fe86db..03c5f449bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
val cursor = ctx.freshName("cursor")
ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
- val tmp = ctx.freshName("tmpBuffer")
+ val tmpBuffer = ctx.freshName("tmpBuffer")
val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
val ev = createConvertCode(ctx, input, dt)
@@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
if ($buffer.length < $numBytes) {
// This will not happen frequently, because the buffer is re-used.
- byte[] $tmp = new byte[$numBytes * 2];
+ byte[] $tmpBuffer = new byte[$numBytes * 2];
Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
- $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
- $buffer = $tmp;
+ $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
+ $buffer = $tmpBuffer;
}
$output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes);
"""
@@ -207,20 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val outputIsNull = ctx.freshName("isNull")
- val tmp = ctx.freshName("tmp")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
+ val elementName = ctx.freshName("elementName")
- val element = GeneratedExpressionCode(
- code = "",
- isNull = s"$tmp.isNullAt($index)",
- primitive = s"${ctx.getValue(tmp, elementType, index)}"
- )
- val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType)
+ val element = {
+ val code = s"${ctx.javaType(elementType)} $elementName = " +
+ s"${ctx.getValue(input.primitive, elementType, index)};"
+ val isNull = s"${input.primitive}.isNullAt($index)"
+ GeneratedExpressionCode(code, isNull, elementName)
+ }
+
+ val convertedElement = createConvertCode(ctx, element, elementType)
// go through the input array to calculate how many bytes we need.
val calculateNumBytes = elementType match {
@@ -272,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
+ ${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -280,6 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
+ ${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -307,11 +311,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${input.code}
final boolean $outputIsNull = ${input.isNull};
if (!$outputIsNull) {
- final ArrayData $tmp = ${input.primitive};
- if ($tmp instanceof UnsafeArrayData) {
- $output = (UnsafeArrayData) $tmp;
+ if (${input.primitive} instanceof UnsafeArrayData) {
+ $output = (UnsafeArrayData) ${input.primitive};
} else {
- final int $numElements = $tmp.numElements();
+ final int $numElements = ${input.primitive}.numElements();
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;
@@ -350,29 +353,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
valueType: DataType): GeneratedExpressionCode = {
val output = ctx.freshName("convertedMap")
val outputIsNull = ctx.freshName("isNull")
- val tmp = ctx.freshName("tmp")
-
- val keyArray = GeneratedExpressionCode(
- code = "",
- isNull = "false",
- primitive = s"$tmp.keyArray()"
- )
- val valueArray = GeneratedExpressionCode(
- code = "",
- isNull = "false",
- primitive = s"$tmp.valueArray()"
- )
- val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType)
- val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType)
+ val keyArrayName = ctx.freshName("keyArrayName")
+ val valueArrayName = ctx.freshName("valueArrayName")
+
+ val keyArray = {
+ val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();"
+ val isNull = "false"
+ GeneratedExpressionCode(code, isNull, keyArrayName)
+ }
+
+ val valueArray = {
+ val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();"
+ val isNull = "false"
+ GeneratedExpressionCode(code, isNull, valueArrayName)
+ }
+
+ val convertedKeys = createCodeForArray(ctx, keyArray, keyType)
+ val convertedValues = createCodeForArray(ctx, valueArray, valueType)
val code = s"""
${input.code}
final boolean $outputIsNull = ${input.isNull};
UnsafeMapData $output = null;
if (!$outputIsNull) {
- final MapData $tmp = ${input.primitive};
- if ($tmp instanceof UnsafeMapData) {
- $output = (UnsafeMapData) $tmp;
+ if (${input.primitive} instanceof UnsafeMapData) {
+ $output = (UnsafeMapData) ${input.primitive};
} else {
${convertedKeys.code}
${convertedValues.code}
@@ -393,22 +398,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType =>
val output = ctx.freshName("convertedStruct")
val outputIsNull = ctx.freshName("isNull")
- val tmp = ctx.freshName("tmp")
val fieldTypes = t.fields.map(_.dataType)
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- val getFieldCode = ctx.getValue(tmp, dt, i.toString)
- val fieldIsNull = s"$tmp.isNullAt($i)"
- GeneratedExpressionCode("", fieldIsNull, getFieldCode)
+ val fieldName = ctx.freshName("fieldName")
+ val code = s"${ctx.javaType(dt)} $fieldName = " +
+ s"${ctx.getValue(input.primitive, dt, i.toString)};"
+ val isNull = s"${input.primitive}.isNullAt($i)"
+ GeneratedExpressionCode(code, isNull, fieldName)
}
- val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
+ val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
final boolean $outputIsNull = ${input.isNull};
if (!$outputIsNull) {
- final InternalRow $tmp = ${input.primitive};
- if ($tmp instanceof UnsafeRow) {
- $output = (UnsafeRow) $tmp;
+ if (${input.primitive} instanceof UnsafeRow) {
+ $output = (UnsafeRow) ${input.primitive};
} else {
${converter.code}
$output = ${converter.primitive};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1c54671973..82eab5fb3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val arrayClass = classOf[GenericArrayData].getName
+ val values = ctx.freshName("values")
s"""
final boolean ${ev.isNull} = false;
- final Object[] values = new Object[${children.size}];
+ final Object[] $values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- values[$i] = null;
+ $values[$i] = null;
} else {
- values[$i] = ${eval.primitive};
+ $values[$i] = ${eval.primitive};
}
"""
}.mkString("\n") +
- s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
+ s"final ArrayData ${ev.primitive} = new $arrayClass($values);"
}
override def prettyName: String = "array"
@@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val rowClass = classOf[GenericMutableRow].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val values = ctx.freshName("values")
s"""
boolean ${ev.isNull} = false;
- final $rowClass ${ev.primitive} = new $rowClass(${children.size});
+ final Object[] $values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
+ $values[$i] = null;
} else {
- ${ev.primitive}.update($i, ${eval.primitive});
+ $values[$i] = ${eval.primitive};
}
"""
- }.mkString("\n")
+ }.mkString("\n") +
+ s"final InternalRow ${ev.primitive} = new $rowClass($values);"
}
override def prettyName: String = "struct"
@@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val rowClass = classOf[GenericMutableRow].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val values = ctx.freshName("values")
s"""
boolean ${ev.isNull} = false;
- final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size});
+ final Object[] $values = new Object[${valExprs.size}];
""" +
valExprs.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
+ $values[$i] = null;
} else {
- ${ev.primitive}.update($i, ${eval.primitive});
+ $values[$i] = ${eval.primitive};
}
"""
- }.mkString("\n")
+ }.mkString("\n") +
+ s"final InternalRow ${ev.primitive} = new $rowClass($values);"
}
override def prettyName: String = "named_struct"