aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-08-01 21:50:42 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-01 21:50:42 -0700
commit57084e0c7c318912208ee31c52d61c14eeddd8f4 (patch)
tree0e43e0f2a10008b8f279d95099a6cc38e72afb7d /sql
parentc1b0cbd762d78bedca0ab564cf9ca0970b7b99d2 (diff)
downloadspark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.gz
spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.tar.bz2
spark-57084e0c7c318912208ee31c52d61c14eeddd8f4.zip
[SPARK-9459] [SQL] use generated FromUnsafeProjection to do deep copy for UTF8String and struct
When accessing a column in UnsafeRow, it's good to avoid the copy, then we should do deep copy when turn the UnsafeRow into generic Row, this PR brings generated FromUnsafeProjection to do that. This PR also fix the expressions that cache the UTF8String, which should also copy it. Author: Davies Liu <davies@databricks.com> Closes #7840 from davies/avoid_copy and squashes the following commits: 230c8a1 [Davies Liu] address comment fd797c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into avoid_copy e095dd0 [Davies Liu] rollback rename 8ef5b0b [Davies Liu] copy String in Columnar 81360b8 [Davies Liu] fix class name 9aecb88 [Davies Liu] use FromUnsafeProjection to do deep copy for UTF8String and struct
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala33
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala147
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala45
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala10
7 files changed, 231 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5a19aa8920..1b475b2492 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -363,15 +363,19 @@ public final class UnsafeRow extends MutableRow {
@Override
public UTF8String getUTF8String(int ordinal) {
assertIndexIsValid(ordinal);
- return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
}
@Override
public byte[] getBinary(int ordinal) {
+ assertIndexIsValid(ordinal);
if (isNullAt(ordinal)) {
return null;
} else {
- assertIndexIsValid(ordinal);
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 000be70f17..83129dc12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@@ -114,8 +114,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
- val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
- create(exprs)
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
}
/**
@@ -139,19 +138,27 @@ object UnsafeProjection {
/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
-case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+object FromUnsafeProjection {
- def this(schema: StructType) = this(schema.fields.map(_.dataType))
-
- private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
- new BoundReference(idx, dt, true)
+ /**
+ * Returns an Projection for given StructType.
+ */
+ def apply(schema: StructType): Projection = {
+ apply(schema.fields.map(_.dataType))
}
- @transient private[this] lazy val generatedProj =
- GenerateMutableProjection.generate(expressions)()
+ /**
+ * Returns an UnsafeProjection for given Array of DataTypes.
+ */
+ def apply(fields: Seq[DataType]): Projection = {
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ }
- override def apply(input: InternalRow): InternalRow = {
- generatedProj(input)
+ /**
+ * Returns an Projection for given sequence of Expressions (bounded).
+ */
+ private def create(exprs: Seq[Expression]): Projection = {
+ GenerateSafeProjection.generate(exprs)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fc7cfee989..3177e6b750 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -127,6 +127,8 @@ class CodeGenContext {
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
+ case StringType => s"$row.update($ordinal, $value.clone())"
case _ => s"$row.update($ordinal, $value)"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
new file mode 100644
index 0000000000..f06ffc5449
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.types.{StringType, StructType, DataType}
+
+
+/**
+ * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
+ * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ */
+object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer.execute)
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ private def genUpdater(
+ ctx: CodeGenContext,
+ setter: String,
+ dataType: DataType,
+ ordinal: Int,
+ value: String): String = {
+ dataType match {
+ case struct: StructType =>
+ val rowTerm = ctx.freshName("row")
+ val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
+ val colTerm = ctx.freshName("col")
+ s"""
+ if ($value.isNullAt($i)) {
+ $rowTerm.setNullAt($i);
+ } else {
+ ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
+ ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
+ }
+ """
+ }.mkString("\n")
+ s"""
+ $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
+ $updates
+ $setter.update($ordinal, $rowTerm.copy());
+ """
+ case _ =>
+ ctx.setColumn(setter, dataType, ordinal, value)
+ }
+ }
+
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val ctx = newCodeGenContext()
+ val projectionCode = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
+ val evaluationCode = e.gen(ctx)
+ evaluationCode.code +
+ s"""
+ if (${evaluationCode.isNull}) {
+ mutableRow.setNullAt($i);
+ } else {
+ ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
+ }
+ """
+ }
+ // collect projections into blocks as function has 64kb codesize limit in JVM
+ val projectionBlocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (projection <- projectionCode) {
+ if (blockBuilder.length > 16 * 1000) {
+ projectionBlocks.append(blockBuilder.toString())
+ blockBuilder.clear()
+ }
+ blockBuilder.append(projection)
+ }
+ projectionBlocks.append(blockBuilder.toString())
+
+ val (projectionFuns, projectionCalls) = {
+ // inline it if we have only one block
+ if (projectionBlocks.length == 1) {
+ ("", projectionBlocks.head)
+ } else {
+ (
+ projectionBlocks.zipWithIndex.map { case (body, i) =>
+ s"""
+ |private void apply$i(InternalRow i) {
+ | $body
+ |}
+ """.stripMargin
+ }.mkString,
+ projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
+ )
+ }
+ }
+
+ val code = s"""
+ public Object generate($exprType[] expr) {
+ return new SpecificSafeProjection(expr);
+ }
+
+ class SpecificSafeProjection extends ${classOf[BaseProjection].getName} {
+
+ private $exprType[] expressions;
+ private $mutableRowType mutableRow;
+ ${declareMutableStates(ctx)}
+
+ public SpecificSafeProjection($exprType[] expr) {
+ expressions = expr;
+ mutableRow = new $genericMutableRowType(${expressions.size});
+ ${initMutableStates(ctx)}
+ }
+
+ $projectionFuns
+
+ public Object apply(Object _i) {
+ InternalRow i = (InternalRow) _i;
+ $projectionCalls
+
+ return mutableRow;
+ }
+ }
+ """
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+
+ val c = compile(code)
+ c.generate(ctx.references.toArray).asInstanceOf[Projection]
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 80c64e5689..56225290cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -971,12 +971,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
if (!r.equals(lastReplacementInUTF8)) {
// replacement string changed
- lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+ lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone()
lastReplacement = lastReplacementInUTF8.toString
}
val m = pattern.matcher(s.toString())
@@ -1022,12 +1022,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
if (!$rep.equals(${termLastReplacementInUTF8})) {
// replacement string changed
- ${termLastReplacementInUTF8} = $rep;
+ ${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
@@ -1061,7 +1061,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString())
@@ -1090,7 +1090,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
java.util.regex.Matcher m =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index f4fbc49677..cc82f7c3f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import scala.math._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Additional tests for code generation.
@@ -93,4 +94,44 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
+
+ test("test generated safe and unsafe projection") {
+ val schema = new StructType(Array(
+ StructField("a", StringType, true),
+ StructField("b", IntegerType, true),
+ StructField("c", new StructType(Array(
+ StructField("aa", StringType, true),
+ StructField("bb", IntegerType, true)
+ )), true),
+ StructField("d", new StructType(Array(
+ StructField("a", new StructType(Array(
+ StructField("b", StringType, true),
+ StructField("", IntegerType, true)
+ )), true)
+ )), true)
+ ))
+ val row = Row("a", 1, Row("b", 2), Row(Row("c", 3)))
+ val lit = Literal.create(row, schema)
+ val internalRow = lit.value.asInstanceOf[InternalRow]
+
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow: UnsafeRow = unsafeProj(internalRow)
+ assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a"))
+ assert(unsafeRow.getInt(1) === 1)
+ assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b"))
+ assert(unsafeRow.getStruct(2, 2).getInt(1) === 2)
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) ===
+ UTF8String.fromString("c"))
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)
+
+ val fromUnsafe = FromUnsafeProjection(schema)
+ val internalRow2 = fromUnsafe(unsafeRow)
+ assert(internalRow === internalRow2)
+
+ // update unsafeRow should not affect internalRow2
+ unsafeRow.setInt(1, 10)
+ unsafeRow.getStruct(2, 2).setInt(1, 10)
+ unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4)
+ assert(internalRow === internalRow2)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 30f8fe320d..531a8244d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -329,7 +329,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
- row.update(ordinal, value)
+ row.update(ordinal, value.clone())
}
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -337,7 +337,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
- to.update(toOrdinal, from.getUTF8String(fromOrdinal))
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}
@@ -396,7 +396,11 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
- row(ordinal) = value
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}