From 106c0789d8c83c7081bc9a335df78ba728e95872 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 8 Aug 2015 08:33:14 -0700 Subject: [SPARK-9738] [SQL] remove FromUnsafe and add its codegen version to GenerateSafe In https://github.com/apache/spark/pull/7752 we added `FromUnsafe` to convert nexted unsafe data like array/map/struct to safe versions. It's a quick solution and we already have `GenerateSafe` to do the conversion which is codegened. So we should remove `FromUnsafe` and implement its codegen version in `GenerateSafe`. Author: Wenchen Fan Closes #8029 from cloud-fan/from-unsafe and squashes the following commits: ed40d8f [Wenchen Fan] add the copy back a93fd4b [Wenchen Fan] cogengen FromUnsafe --- .../sql/catalyst/expressions/FromUnsafe.scala | 70 ------------ .../sql/catalyst/expressions/Projection.scala | 8 +- .../codegen/GenerateSafeProjection.scala | 120 ++++++++++++++++----- .../sql/execution/RowFormatConvertersSuite.scala | 4 +- 4 files changed, 95 insertions(+), 107 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala deleted file mode 100644 index 9b960b136f..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -case class FromUnsafe(child: Expression) extends UnaryExpression - with ExpectsInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(ArrayType, StructType, MapType)) - - override def dataType: DataType = child.dataType - - private def convert(value: Any, dt: DataType): Any = dt match { - case StructType(fields) => - val row = value.asInstanceOf[UnsafeRow] - val result = new Array[Any](fields.length) - fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => - if (!row.isNullAt(i)) { - result(i) = convert(row.get(i, dt), dt) - } - } - new GenericInternalRow(result) - - case ArrayType(elementType, _) => - val array = value.asInstanceOf[UnsafeArrayData] - val length = array.numElements() - val result = new Array[Any](length) - var i = 0 - while (i < length) { - if (!array.isNullAt(i)) { - result(i) = convert(array.get(i, elementType), elementType) - } - i += 1 - } - new GenericArrayData(result) - - case StringType => value.asInstanceOf[UTF8String].clone() - - case MapType(kt, vt, _) => - val map = value.asInstanceOf[UnsafeMapData] - val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] - val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] - new ArrayBasedMapData(safeKeyArray, safeValueArray) - - case _ => value - } - - override def nullSafeEval(input: Any): Any = { - convert(input, dataType) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 796bc327a3..afe52e6a66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -152,13 +152,7 @@ object FromUnsafeProjection { */ def apply(fields: Seq[DataType]): Projection = { create(fields.zipWithIndex.map(x => { - val b = new BoundReference(x._2, x._1, true) - // todo: this is quite slow, maybe remove this whole projection after remove generic getter of - // InternalRow? - b.dataType match { - case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b) - case _ => b - } + new BoundReference(x._2, x._1, true) })) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f06ffc5449..ef08ddf041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.types.{StringType, StructType, DataType} +import org.apache.spark.sql.types._ /** @@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - private def genUpdater( + private def createCodeForStruct( ctx: CodeGenContext, - setter: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case struct: StructType => - val rowTerm = ctx.freshName("row") - val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => - val colTerm = ctx.freshName("col") - s""" - if ($value.isNullAt($i)) { - $rowTerm.setNullAt($i); - } else { - ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; - ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; - } - """ - }.mkString("\n") - s""" - $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); - $updates - $setter.update($ordinal, $rowTerm.copy()); - """ - case _ => - ctx.setColumn(setter, dataType, ordinal, value) - } + input: String, + schema: StructType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeRow") + val values = ctx.freshName("values") + val rowClass = classOf[GenericInternalRow].getName + + val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => + val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + s""" + if (!$tmp.isNullAt($i)) { + ${converter.code} + $values[$i] = ${converter.primitive}; + } + """ + }.mkString("\n") + + val code = s""" + final InternalRow $tmp = $input; + final Object[] $values = new Object[${schema.length}]; + $fieldWriters + final InternalRow $output = new $rowClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForArray( + ctx: CodeGenContext, + input: String, + elementType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeArray") + val values = ctx.freshName("values") + val numElements = ctx.freshName("numElements") + val index = ctx.freshName("index") + val arrayClass = classOf[GenericArrayData].getName + + val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val code = s""" + final ArrayData $tmp = $input; + final int $numElements = $tmp.numElements(); + final Object[] $values = new Object[$numElements]; + for (int $index = 0; $index < $numElements; $index++) { + if (!$tmp.isNullAt($index)) { + ${elementConverter.code} + $values[$index] = ${elementConverter.primitive}; + } + } + final ArrayData $output = new $arrayClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForMap( + ctx: CodeGenContext, + input: String, + keyType: DataType, + valueType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeMap") + val mapClass = classOf[ArrayBasedMapData].getName + + val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val code = s""" + final MapData $tmp = $input; + ${keyConverter.code} + ${valueConverter.code} + final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive}); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def convertToSafe( + ctx: CodeGenContext, + input: String, + dataType: DataType): GeneratedExpressionCode = dataType match { + case s: StructType => createCodeForStruct(ctx, input, s) + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. + case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case _ => GeneratedExpressionCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { @@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType) evaluationCode.code + s""" if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { - ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + ${converter.code} + ${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)}; } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 322966f423..dd08e9025a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode { override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - // cache all strings to make sure we have deep copied UTF8String inside incoming + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming // safe InternalRow. val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] iter.foreach { row => -- cgit v1.2.3