aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-08 08:33:14 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-08 08:33:14 -0700
commit106c0789d8c83c7081bc9a335df78ba728e95872 (patch)
tree9938bb7cd1465e110c0ef1dcfe4ca7ab7f19dce1 /sql
parent11caf1ce290b6931647c2f71268f847d1d48930e (diff)
downloadspark-106c0789d8c83c7081bc9a335df78ba728e95872.tar.gz
spark-106c0789d8c83c7081bc9a335df78ba728e95872.tar.bz2
spark-106c0789d8c83c7081bc9a335df78ba728e95872.zip
[SPARK-9738] [SQL] remove FromUnsafe and add its codegen version to GenerateSafe
In https://github.com/apache/spark/pull/7752 we added `FromUnsafe` to convert nested unsafe data like array/map/struct to safe versions. It was a quick solution, and we already have `GenerateSafe`, which does the conversion via codegen. So we should remove `FromUnsafe` and implement its codegen version in `GenerateSafe`. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8029 from cloud-fan/from-unsafe and squashes the following commits: ed40d8f [Wenchen Fan] add the copy back a93fd4b [Wenchen Fan] cogengen FromUnsafe
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala70
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala120
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala4
4 files changed, 95 insertions, 107 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
deleted file mode 100644
index 9b960b136f..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-case class FromUnsafe(child: Expression) extends UnaryExpression
- with ExpectsInputTypes with CodegenFallback {
-
- override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(ArrayType, StructType, MapType))
-
- override def dataType: DataType = child.dataType
-
- private def convert(value: Any, dt: DataType): Any = dt match {
- case StructType(fields) =>
- val row = value.asInstanceOf[UnsafeRow]
- val result = new Array[Any](fields.length)
- fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) =>
- if (!row.isNullAt(i)) {
- result(i) = convert(row.get(i, dt), dt)
- }
- }
- new GenericInternalRow(result)
-
- case ArrayType(elementType, _) =>
- val array = value.asInstanceOf[UnsafeArrayData]
- val length = array.numElements()
- val result = new Array[Any](length)
- var i = 0
- while (i < length) {
- if (!array.isNullAt(i)) {
- result(i) = convert(array.get(i, elementType), elementType)
- }
- i += 1
- }
- new GenericArrayData(result)
-
- case StringType => value.asInstanceOf[UTF8String].clone()
-
- case MapType(kt, vt, _) =>
- val map = value.asInstanceOf[UnsafeMapData]
- val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData]
- val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData]
- new ArrayBasedMapData(safeKeyArray, safeValueArray)
-
- case _ => value
- }
-
- override def nullSafeEval(input: Any): Any = {
- convert(input, dataType)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 796bc327a3..afe52e6a66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -152,13 +152,7 @@ object FromUnsafeProjection {
*/
def apply(fields: Seq[DataType]): Projection = {
create(fields.zipWithIndex.map(x => {
- val b = new BoundReference(x._2, x._1, true)
- // todo: this is quite slow, maybe remove this whole projection after remove generic getter of
- // InternalRow?
- b.dataType match {
- case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
- case _ => b
- }
+ new BoundReference(x._2, x._1, true)
}))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f06ffc5449..ef08ddf041 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
-import org.apache.spark.sql.types.{StringType, StructType, DataType}
+import org.apache.spark.sql.types._
/**
@@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
- private def genUpdater(
+ private def createCodeForStruct(
ctx: CodeGenContext,
- setter: String,
- dataType: DataType,
- ordinal: Int,
- value: String): String = {
- dataType match {
- case struct: StructType =>
- val rowTerm = ctx.freshName("row")
- val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
- val colTerm = ctx.freshName("col")
- s"""
- if ($value.isNullAt($i)) {
- $rowTerm.setNullAt($i);
- } else {
- ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
- ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
- }
- """
- }.mkString("\n")
- s"""
- $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
- $updates
- $setter.update($ordinal, $rowTerm.copy());
- """
- case _ =>
- ctx.setColumn(setter, dataType, ordinal, value)
- }
+ input: String,
+ schema: StructType): GeneratedExpressionCode = {
+ val tmp = ctx.freshName("tmp")
+ val output = ctx.freshName("safeRow")
+ val values = ctx.freshName("values")
+ val rowClass = classOf[GenericInternalRow].getName
+
+ val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
+ val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
+ s"""
+ if (!$tmp.isNullAt($i)) {
+ ${converter.code}
+ $values[$i] = ${converter.primitive};
+ }
+ """
+ }.mkString("\n")
+
+ val code = s"""
+ final InternalRow $tmp = $input;
+ final Object[] $values = new Object[${schema.length}];
+ $fieldWriters
+ final InternalRow $output = new $rowClass($values);
+ """
+
+ GeneratedExpressionCode(code, "false", output)
+ }
+
+ private def createCodeForArray(
+ ctx: CodeGenContext,
+ input: String,
+ elementType: DataType): GeneratedExpressionCode = {
+ val tmp = ctx.freshName("tmp")
+ val output = ctx.freshName("safeArray")
+ val values = ctx.freshName("values")
+ val numElements = ctx.freshName("numElements")
+ val index = ctx.freshName("index")
+ val arrayClass = classOf[GenericArrayData].getName
+
+ val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
+ val code = s"""
+ final ArrayData $tmp = $input;
+ final int $numElements = $tmp.numElements();
+ final Object[] $values = new Object[$numElements];
+ for (int $index = 0; $index < $numElements; $index++) {
+ if (!$tmp.isNullAt($index)) {
+ ${elementConverter.code}
+ $values[$index] = ${elementConverter.primitive};
+ }
+ }
+ final ArrayData $output = new $arrayClass($values);
+ """
+
+ GeneratedExpressionCode(code, "false", output)
+ }
+
+ private def createCodeForMap(
+ ctx: CodeGenContext,
+ input: String,
+ keyType: DataType,
+ valueType: DataType): GeneratedExpressionCode = {
+ val tmp = ctx.freshName("tmp")
+ val output = ctx.freshName("safeMap")
+ val mapClass = classOf[ArrayBasedMapData].getName
+
+ val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
+ val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
+ val code = s"""
+ final MapData $tmp = $input;
+ ${keyConverter.code}
+ ${valueConverter.code}
+ final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive});
+ """
+
+ GeneratedExpressionCode(code, "false", output)
+ }
+
+ private def convertToSafe(
+ ctx: CodeGenContext,
+ input: String,
+ dataType: DataType): GeneratedExpressionCode = dataType match {
+ case s: StructType => createCodeForStruct(ctx, input, s)
+ case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
+ case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
+ // UTF8String acts as a pointer if it's inside UnsafeRow, so copy it to make it safe.
+ case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
+ case _ => GeneratedExpressionCode("", "false", input)
}
protected def create(expressions: Seq[Expression]): Projection = {
@@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
+ val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType)
evaluationCode.code +
s"""
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i);
} else {
- ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
+ ${converter.code}
+ ${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)};
}
"""
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 322966f423..dd08e9025a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode {
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
- // cache all strings to make sure we have deep copied UTF8String inside incoming
+ // This `DummyPlan` is in safe mode, so we don't need to copy even if we hold on to some
+ // values obtained from the incoming rows.
+ // We cache all strings here to make sure UTF8Strings are deep copied inside the incoming
// safe InternalRow.
val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
iter.foreach { row =>