diff options
author | Davies Liu <davies@databricks.com> | 2015-08-01 21:50:42 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-01 21:50:42 -0700 |
commit | 57084e0c7c318912208ee31c52d61c14eeddd8f4 (patch) | |
tree | 0e43e0f2a10008b8f279d95099a6cc38e72afb7d /sql/catalyst | |
parent | c1b0cbd762d78bedca0ab564cf9ca0970b7b99d2 (diff) | |
download | spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.gz spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.bz2 spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.zip |
[SPARK-9459] [SQL] use generated FromUnsafeProjection to do deep copy for UTF8String and struct
When accessing a column in UnsafeRow, it's good to avoid the copy; instead, we should do a deep copy when turning the UnsafeRow into a generic Row. This PR brings a generated FromUnsafeProjection to do that.
This PR also fixes the expressions that cache the UTF8String, which should also copy it.
Author: Davies Liu <davies@databricks.com>
Closes #7840 from davies/avoid_copy and squashes the following commits:
230c8a1 [Davies Liu] address comment
fd797c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into avoid_copy
e095dd0 [Davies Liu] rollback rename
8ef5b0b [Davies Liu] copy String in Columnar
81360b8 [Davies Liu] fix class name
9aecb88 [Davies Liu] use FromUnsafeProjection to do deep copy for UTF8String and struct
Diffstat (limited to 'sql/catalyst')
6 files changed, 224 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5a19aa8920..1b475b2492 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -363,15 +363,19 @@ public final class UnsafeRow extends MutableRow { @Override public UTF8String getUTF8String(int ordinal) { assertIndexIsValid(ordinal); - return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override public byte[] getBinary(int ordinal) { + assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(ordinal); final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 000be70f17..83129dc12d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, 
GenerateUnsafeProjection} +import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** @@ -114,8 +114,7 @@ object UnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. */ def create(fields: Array[DataType]): UnsafeProjection = { - val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) - create(exprs) + create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) } /** @@ -139,19 +138,27 @@ object UnsafeProjection { /** * A projection that could turn UnsafeRow into GenericInternalRow */ -case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { +object FromUnsafeProjection { - def this(schema: StructType) = this(schema.fields.map(_.dataType)) - - private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => - new BoundReference(idx, dt, true) + /** + * Returns an Projection for given StructType. + */ + def apply(schema: StructType): Projection = { + apply(schema.fields.map(_.dataType)) } - @transient private[this] lazy val generatedProj = - GenerateMutableProjection.generate(expressions)() + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def apply(fields: Seq[DataType]): Projection = { + create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + } - override def apply(input: InternalRow): InternalRow = { - generatedProj(input) + /** + * Returns an Projection for given sequence of Expressions (bounded). 
+ */ + private def create(exprs: Seq[Expression]): Projection = { + GenerateSafeProjection.generate(exprs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fc7cfee989..3177e6b750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -127,6 +127,8 @@ class CodeGenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) + case StringType => s"$row.update($ordinal, $value.clone())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala new file mode 100644 index 0000000000..f06ffc5449 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.{StringType, StructType, DataType} + + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] { + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + private def genUpdater( + ctx: CodeGenContext, + setter: String, + dataType: DataType, + ordinal: Int, + value: String): String = { + dataType match { + case struct: StructType => + val rowTerm = ctx.freshName("row") + val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => + val colTerm = ctx.freshName("col") + s""" + if ($value.isNullAt($i)) { + $rowTerm.setNullAt($i); + } else { + ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; + ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; + } + """ + }.mkString("\n") + s""" + $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); + $updates + $setter.update($ordinal, $rowTerm.copy()); + """ + case _ => + ctx.setColumn(setter, 
dataType, ordinal, value) + } + } + + protected def create(expressions: Seq[Expression]): Projection = { + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if (${evaluationCode.isNull}) { + mutableRow.setNullAt($i); + } else { + ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + } + """ + } + // collect projections into blocks as function has 64kb codesize limit in JVM + val projectionBlocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (projection <- projectionCode) { + if (blockBuilder.length > 16 * 1000) { + projectionBlocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(projection) + } + projectionBlocks.append(blockBuilder.toString()) + + val (projectionFuns, projectionCalls) = { + // inline it if we have only one block + if (projectionBlocks.length == 1) { + ("", projectionBlocks.head) + } else { + ( + projectionBlocks.zipWithIndex.map { case (body, i) => + s""" + |private void apply$i(InternalRow i) { + | $body + |} + """.stripMargin + }.mkString, + projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") + ) + } + } + + val code = s""" + public Object generate($exprType[] expr) { + return new SpecificSafeProjection(expr); + } + + class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { + + private $exprType[] expressions; + private $mutableRowType mutableRow; + ${declareMutableStates(ctx)} + + public SpecificSafeProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + ${initMutableStates(ctx)} + } + + $projectionFuns + + public Object apply(Object _i) { + InternalRow i = (InternalRow) _i; + $projectionCalls + + return mutableRow; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") + + val c = 
compile(code) + c.generate(ctx.references.toArray).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 80c64e5689..56225290cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -971,12 +971,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def nullSafeEval(s: Any, p: Any, r: Any): Any = { if (!p.equals(lastRegex)) { // regex value changed - lastRegex = p.asInstanceOf[UTF8String] + lastRegex = p.asInstanceOf[UTF8String].clone() pattern = Pattern.compile(lastRegex.toString) } if (!r.equals(lastReplacementInUTF8)) { // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } val m = pattern.matcher(s.toString()) @@ -1022,12 +1022,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" if (!$regexp.equals(${termLastRegex})) { // regex value changed - ${termLastRegex} = $regexp; + ${termLastRegex} = $regexp.clone(); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } if (!$rep.equals(${termLastReplacementInUTF8})) { // replacement string changed - ${termLastReplacementInUTF8} = $rep; + ${termLastReplacementInUTF8} = $rep.clone(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } ${termResult}.delete(0, ${termResult}.length()); @@ -1061,7 +1061,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def nullSafeEval(s: Any, p: Any, r: Any): Any = { if (!p.equals(lastRegex)) { // regex value changed - lastRegex = p.asInstanceOf[UTF8String] + 
lastRegex = p.asInstanceOf[UTF8String].clone() pattern = Pattern.compile(lastRegex.toString) } val m = pattern.matcher(s.toString()) @@ -1090,7 +1090,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio s""" if (!$regexp.equals(${termLastRegex})) { // regex value changed - ${termLastRegex} = $regexp; + ${termLastRegex} = $regexp.clone(); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } java.util.regex.Matcher m = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f4fbc49677..cc82f7c3f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.{Row, RandomDataGenerator} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Additional tests for code generation. 
@@ -93,4 +94,44 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } + + test("test generated safe and unsafe projection") { + val schema = new StructType(Array( + StructField("a", StringType, true), + StructField("b", IntegerType, true), + StructField("c", new StructType(Array( + StructField("aa", StringType, true), + StructField("bb", IntegerType, true) + )), true), + StructField("d", new StructType(Array( + StructField("a", new StructType(Array( + StructField("b", StringType, true), + StructField("", IntegerType, true) + )), true) + )), true) + )) + val row = Row("a", 1, Row("b", 2), Row(Row("c", 3))) + val lit = Literal.create(row, schema) + val internalRow = lit.value.asInstanceOf[InternalRow] + + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow: UnsafeRow = unsafeProj(internalRow) + assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a")) + assert(unsafeRow.getInt(1) === 1) + assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b")) + assert(unsafeRow.getStruct(2, 2).getInt(1) === 2) + assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) === + UTF8String.fromString("c")) + assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3) + + val fromUnsafe = FromUnsafeProjection(schema) + val internalRow2 = fromUnsafe(unsafeRow) + assert(internalRow === internalRow2) + + // update unsafeRow should not affect internalRow2 + unsafeRow.setInt(1, 10) + unsafeRow.getStruct(2, 2).setInt(1, 10) + unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) + assert(internalRow === internalRow2) + } } |