about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author Wenchen Fan <cloud0fan@outlook.com> 2015-08-04 14:40:46 -0700
committer Reynold Xin <rxin@databricks.com> 2015-08-04 14:40:46 -0700
commit f4b1ac08a1327e6d0ddc317cdf3997a0f68dec72 (patch)
tree a97a977ae1176db91f7dcebec3d96a6dd46faadc /sql
parent a0cc01759b0c2cecf340c885d391976eb4e3fad6 (diff)
download spark-f4b1ac08a1327e6d0ddc317cdf3997a0f68dec72.tar.gz
spark-f4b1ac08a1327e6d0ddc317cdf3997a0f68dec72.tar.bz2
spark-f4b1ac08a1327e6d0ddc317cdf3997a0f68dec72.zip
[SPARK-9553][SQL] remove the no-longer-necessary createCode and createStructCode, and replace the usage
Author: Wenchen Fan <cloud0fan@outlook.com>. Closes #7890 from cloud-fan/minor and squashes the following commits: c3b1be3 [Wenchen Fan] fix style; b0cbe2e [Wenchen Fan] remove the createCode and createStructCode, and replace the usage of them by createStructCode.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 161
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala 10
2 files changed, 17 insertions, 154 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index fc3ecf5451..71f8ea09f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -117,161 +117,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
/**
- * Generates the code to create an [[UnsafeRow]] object based on the input expressions.
- * @param ctx context for code generation
- * @param ev specifies the name of the variable for the output [[UnsafeRow]] object
- * @param expressions input expressions
- * @return generated code to put the expression output into an [[UnsafeRow]]
- */
- def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression])
- : String = {
-
- val ret = ev.primitive
- ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
- val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
- val cursor = ctx.freshName("cursor")
- val numBytes = ctx.freshName("numBytes")
-
- val exprs = expressions.map { e => e.dataType match {
- case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st)
- case _ => e.gen(ctx)
- }}
- val allExprs = exprs.map(_.code).mkString("\n")
-
- val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
- val additionalSize = expressions.zipWithIndex.map {
- case (e, i) => genAdditionalSize(e.dataType, exprs(i))
- }.mkString("")
-
- val writers = expressions.zipWithIndex.map { case (e, i) =>
- val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor)
- s"""if (${exprs(i).isNull}) {
- $ret.setNullAt($i);
- } else {
- $update;
- }"""
- }.mkString("\n ")
-
- s"""
- $allExprs
- int $numBytes = $fixedSize $additionalSize;
- if ($numBytes > $buffer.length) {
- $buffer = new byte[$numBytes];
- }
-
- $ret.pointTo(
- $buffer,
- $PlatformDependent.BYTE_ARRAY_OFFSET,
- ${expressions.size},
- $numBytes);
- int $cursor = $fixedSize;
-
- $writers
- boolean ${ev.isNull} = false;
- """
- }
-
- /**
- * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
- *
- * This function also handles nested structs by recursively generating the code to do conversion.
- *
- * @param ctx code generation context
- * @param input the input struct, identified by a [[GeneratedExpressionCode]]
- * @param schema schema of the struct field
- */
- // TODO: refactor createCode and this function to reduce code duplication.
- private def createCodeForStruct(
- ctx: CodeGenContext,
- input: GeneratedExpressionCode,
- schema: StructType): GeneratedExpressionCode = {
-
- val isNull = input.isNull
- val primitive = ctx.freshName("structConvert")
- ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
- val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
- val cursor = ctx.freshName("cursor")
-
- val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
- case (dt, i) => dt match {
- case st: StructType =>
- val nestedStructEv = GeneratedExpressionCode(
- code = "",
- isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
- )
- createCodeForStruct(ctx, nestedStructEv, st)
- case _ =>
- GeneratedExpressionCode(
- code = "",
- isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
- )
- }
- }
- val allExprs = exprs.map(_.code).mkString("\n")
-
- val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
- val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
- genAdditionalSize(dt, ev)
- }.mkString("")
-
- val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
- val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
- s"""
- if (${exprs(i).isNull}) {
- $primitive.setNullAt($i);
- } else {
- $update;
- }
- """
- }.mkString("\n ")
-
- // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
- // just copy the bytes directly into our buffer space without running any conversion.
- // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from
- // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
- val tmp = ctx.freshName("tmp")
- val numBytes = ctx.freshName("numBytes")
- val code = s"""
- |${input.code}
- |if (!${input.isNull}) {
- | Object $tmp = (Object) ${input.primitive};
- | if ($tmp instanceof UnsafeRow) {
- | $primitive = (UnsafeRow) $tmp;
- | } else {
- | $allExprs
- |
- | int $numBytes = $fixedSize $additionalSize;
- | if ($numBytes > $buffer.length) {
- | $buffer = new byte[$numBytes];
- | }
- |
- | $primitive.pointTo(
- | $buffer,
- | $PlatformDependent.BYTE_ARRAY_OFFSET,
- | ${exprs.size},
- | $numBytes);
- | int $cursor = $fixedSize;
- |
- | $writers
- | }
- |}
- """.stripMargin
-
- GeneratedExpressionCode(code, isNull, primitive)
- }
-
- /**
* Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
*
* @param ctx code generation context
* @param inputs could be the codes for expressions or input struct fields.
* @param inputTypes types of the inputs
*/
- private def createCodeForStruct2(
+ private def createCodeForStruct(
ctx: CodeGenContext,
inputs: Seq[GeneratedExpressionCode],
inputTypes: Seq[DataType]): GeneratedExpressionCode = {
@@ -537,7 +389,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldIsNull = s"$tmp.isNullAt($i)"
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
}
- val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes)
+ val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
@@ -561,6 +413,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => input
}
+ def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
+ val exprEvals = expressions.map(e => e.gen(ctx))
+ val exprTypes = expressions.map(_.dataType)
+ createCodeForStruct(ctx, exprEvals, exprTypes)
+ }
+
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -570,8 +428,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
val ctx = newCodeGenContext()
- val exprEvals = expressions.map(e => e.gen(ctx))
- val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType))
+ val eval = createCode(ctx, expressions)
val code = s"""
public Object generate($exprType[] exprs) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index a145dfb4bb..4a071e663e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -211,7 +211,10 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- GenerateUnsafeProjection.createCode(ctx, ev, children)
+ val eval = GenerateUnsafeProjection.createCode(ctx, children)
+ ev.isNull = eval.isNull
+ ev.primitive = eval.primitive
+ eval.code
}
override def prettyName: String = "struct_unsafe"
@@ -246,7 +249,10 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- GenerateUnsafeProjection.createCode(ctx, ev, valExprs)
+ val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
+ ev.isNull = eval.isNull
+ ev.primitive = eval.primitive
+ eval.code
}
override def prettyName: String = "named_struct_unsafe"