aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-08-01 21:50:42 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-01 21:50:42 -0700
commit57084e0c7c318912208ee31c52d61c14eeddd8f4 (patch)
tree0e43e0f2a10008b8f279d95099a6cc38e72afb7d /sql/catalyst/src/main/scala
parentc1b0cbd762d78bedca0ab564cf9ca0970b7b99d2 (diff)
downloadspark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.gz
spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.bz2
spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.zip
[SPARK-9459] [SQL] use generated FromUnsafeProjection to do deep copy for UTF8String and struct
When accessing a column in UnsafeRow, it's good to avoid the copy, then we should do deep copy when turn the UnsafeRow into generic Row, this PR brings generated FromUnsafeProjection to do that. This PR also fix the expressions that cache the UTF8String, which should also copy it. Author: Davies Liu <davies@databricks.com> Closes #7840 from davies/avoid_copy and squashes the following commits: 230c8a1 [Davies Liu] address comment fd797c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into avoid_copy e095dd0 [Davies Liu] rollback rename 8ef5b0b [Davies Liu] copy String in Columnar 81360b8 [Davies Liu] fix class name 9aecb88 [Davies Liu] use FromUnsafeProjection to do deep copy for UTF8String and struct
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala33
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala147
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala12
4 files changed, 175 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 000be70f17..83129dc12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@@ -114,8 +114,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
- val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
- create(exprs)
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
}
/**
@@ -139,19 +138,27 @@ object UnsafeProjection {
/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
-case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+object FromUnsafeProjection {
- def this(schema: StructType) = this(schema.fields.map(_.dataType))
-
- private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
- new BoundReference(idx, dt, true)
+ /**
+ * Returns a Projection for the given StructType.
+ */
+ def apply(schema: StructType): Projection = {
+ apply(schema.fields.map(_.dataType))
}
- @transient private[this] lazy val generatedProj =
- GenerateMutableProjection.generate(expressions)()
+ /**
+ * Returns a Projection for the given sequence of DataTypes.
+ */
+ def apply(fields: Seq[DataType]): Projection = {
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ }
- override def apply(input: InternalRow): InternalRow = {
- generatedProj(input)
+ /**
+ * Returns a Projection for the given sequence of (bound) Expressions.
+ */
+ private def create(exprs: Seq[Expression]): Projection = {
+ GenerateSafeProjection.generate(exprs)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fc7cfee989..3177e6b750 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -127,6 +127,8 @@ class CodeGenContext {
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ // The UTF8String may have come from an UnsafeRow; otherwise clone is cheap (re-uses the bytes)
+ case StringType => s"$row.update($ordinal, $value.clone())"
case _ => s"$row.update($ordinal, $value)"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
new file mode 100644
index 0000000000..f06ffc5449
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.types.{StringType, StructType, DataType}
+
+
+/**
+ * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
+ * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ */
+object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer.execute)
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ private def genUpdater(
+ ctx: CodeGenContext,
+ setter: String,
+ dataType: DataType,
+ ordinal: Int,
+ value: String): String = {
+ dataType match {
+ case struct: StructType =>
+ val rowTerm = ctx.freshName("row")
+ val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
+ val colTerm = ctx.freshName("col")
+ s"""
+ if ($value.isNullAt($i)) {
+ $rowTerm.setNullAt($i);
+ } else {
+ ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
+ ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
+ }
+ """
+ }.mkString("\n")
+ s"""
+ $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
+ $updates
+ $setter.update($ordinal, $rowTerm.copy());
+ """
+ case _ =>
+ ctx.setColumn(setter, dataType, ordinal, value)
+ }
+ }
+
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val ctx = newCodeGenContext()
+ val projectionCode = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
+ val evaluationCode = e.gen(ctx)
+ evaluationCode.code +
+ s"""
+ if (${evaluationCode.isNull}) {
+ mutableRow.setNullAt($i);
+ } else {
+ ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
+ }
+ """
+ }
+ // collect projections into blocks as function has 64kb codesize limit in JVM
+ val projectionBlocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (projection <- projectionCode) {
+ if (blockBuilder.length > 16 * 1000) {
+ projectionBlocks.append(blockBuilder.toString())
+ blockBuilder.clear()
+ }
+ blockBuilder.append(projection)
+ }
+ projectionBlocks.append(blockBuilder.toString())
+
+ val (projectionFuns, projectionCalls) = {
+ // inline it if we have only one block
+ if (projectionBlocks.length == 1) {
+ ("", projectionBlocks.head)
+ } else {
+ (
+ projectionBlocks.zipWithIndex.map { case (body, i) =>
+ s"""
+ |private void apply$i(InternalRow i) {
+ | $body
+ |}
+ """.stripMargin
+ }.mkString,
+ projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
+ )
+ }
+ }
+
+ val code = s"""
+ public Object generate($exprType[] expr) {
+ return new SpecificSafeProjection(expr);
+ }
+
+ class SpecificSafeProjection extends ${classOf[BaseProjection].getName} {
+
+ private $exprType[] expressions;
+ private $mutableRowType mutableRow;
+ ${declareMutableStates(ctx)}
+
+ public SpecificSafeProjection($exprType[] expr) {
+ expressions = expr;
+ mutableRow = new $genericMutableRowType(${expressions.size});
+ ${initMutableStates(ctx)}
+ }
+
+ $projectionFuns
+
+ public Object apply(Object _i) {
+ InternalRow i = (InternalRow) _i;
+ $projectionCalls
+
+ return mutableRow;
+ }
+ }
+ """
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+
+ val c = compile(code)
+ c.generate(ctx.references.toArray).asInstanceOf[Projection]
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 80c64e5689..56225290cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -971,12 +971,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
if (!r.equals(lastReplacementInUTF8)) {
// replacement string changed
- lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+ lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone()
lastReplacement = lastReplacementInUTF8.toString
}
val m = pattern.matcher(s.toString())
@@ -1022,12 +1022,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
if (!$rep.equals(${termLastReplacementInUTF8})) {
// replacement string changed
- ${termLastReplacementInUTF8} = $rep;
+ ${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
@@ -1061,7 +1061,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString())
@@ -1090,7 +1090,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
java.util.regex.Matcher m =