diff options
Diffstat (limited to 'sql/catalyst')
6 files changed, 224 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5a19aa8920..1b475b2492 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -363,15 +363,19 @@ public final class UnsafeRow extends MutableRow { @Override public UTF8String getUTF8String(int ordinal) { assertIndexIsValid(ordinal); - return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override public byte[] getBinary(int ordinal) { + assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(ordinal); final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 000be70f17..83129dc12d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** @@ -114,8 +114,7 @@ object UnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. */ def create(fields: Array[DataType]): UnsafeProjection = { - val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) - create(exprs) + create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) } /** @@ -139,19 +138,27 @@ object UnsafeProjection { /** * A projection that could turn UnsafeRow into GenericInternalRow */ -case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { +object FromUnsafeProjection { - def this(schema: StructType) = this(schema.fields.map(_.dataType)) - - private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => - new BoundReference(idx, dt, true) + /** + * Returns an Projection for given StructType. + */ + def apply(schema: StructType): Projection = { + apply(schema.fields.map(_.dataType)) } - @transient private[this] lazy val generatedProj = - GenerateMutableProjection.generate(expressions)() + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def apply(fields: Seq[DataType]): Projection = { + create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + } - override def apply(input: InternalRow): InternalRow = { - generatedProj(input) + /** + * Returns an Projection for given sequence of Expressions (bounded). + */ + private def create(exprs: Seq[Expression]): Projection = { + GenerateSafeProjection.generate(exprs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fc7cfee989..3177e6b750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -127,6 +127,8 @@ class CodeGenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) + case StringType => s"$row.update($ordinal, $value.clone())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala new file mode 100644 index 0000000000..f06ffc5449 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.{StringType, StructType, DataType} + + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] { + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + private def genUpdater( + ctx: CodeGenContext, + setter: String, + dataType: DataType, + ordinal: Int, + value: String): String = { + dataType match { + case struct: StructType => + val rowTerm = ctx.freshName("row") + val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => + val colTerm = ctx.freshName("col") + s""" + if ($value.isNullAt($i)) { + $rowTerm.setNullAt($i); + } else { + ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; + ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; + } + """ + }.mkString("\n") + s""" + $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); + $updates + $setter.update($ordinal, $rowTerm.copy()); + """ + case _ => + ctx.setColumn(setter, dataType, ordinal, value) + } + } + + protected def create(expressions: Seq[Expression]): Projection = { + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if (${evaluationCode.isNull}) { + mutableRow.setNullAt($i); + } else { + ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + } + """ + } + // collect projections into blocks as function has 64kb codesize limit in JVM + val projectionBlocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (projection <- projectionCode) { + if (blockBuilder.length > 16 * 1000) { + projectionBlocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(projection) + } + projectionBlocks.append(blockBuilder.toString()) + + val (projectionFuns, projectionCalls) = { + // inline it if we have only one block + if (projectionBlocks.length == 1) { + ("", projectionBlocks.head) + } else { + ( + projectionBlocks.zipWithIndex.map { case (body, i) => + s""" + |private void apply$i(InternalRow i) { + | $body + |} + """.stripMargin + }.mkString, + projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") + ) + } + } + + val code = s""" + public Object generate($exprType[] expr) { + return new SpecificSafeProjection(expr); + } + + class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { + + private $exprType[] expressions; + private $mutableRowType mutableRow; + ${declareMutableStates(ctx)} + + public SpecificSafeProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + ${initMutableStates(ctx)} + } + + $projectionFuns + + public Object apply(Object _i) { + InternalRow i = (InternalRow) _i; + $projectionCalls + + return mutableRow; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") + + val c = compile(code) + c.generate(ctx.references.toArray).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 80c64e5689..56225290cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -971,12 +971,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def nullSafeEval(s: Any, p: Any, r: Any): Any = { if (!p.equals(lastRegex)) { // regex value changed - lastRegex = p.asInstanceOf[UTF8String] + lastRegex = p.asInstanceOf[UTF8String].clone() pattern = Pattern.compile(lastRegex.toString) } if (!r.equals(lastReplacementInUTF8)) { // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } val m = pattern.matcher(s.toString()) @@ -1022,12 +1022,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" if (!$regexp.equals(${termLastRegex})) { // regex value changed - ${termLastRegex} = $regexp; + ${termLastRegex} = $regexp.clone(); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } if (!$rep.equals(${termLastReplacementInUTF8})) { // replacement string changed - ${termLastReplacementInUTF8} = $rep; + ${termLastReplacementInUTF8} = $rep.clone(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } ${termResult}.delete(0, ${termResult}.length()); @@ -1061,7 +1061,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def nullSafeEval(s: Any, p: Any, r: Any): Any = { if (!p.equals(lastRegex)) { // regex value changed - lastRegex = p.asInstanceOf[UTF8String] + lastRegex = p.asInstanceOf[UTF8String].clone() pattern = Pattern.compile(lastRegex.toString) } val m = pattern.matcher(s.toString()) @@ -1090,7 +1090,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio s""" if (!$regexp.equals(${termLastRegex})) { // regex value changed - ${termLastRegex} = $regexp; + ${termLastRegex} = $regexp.clone(); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } java.util.regex.Matcher m = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f4fbc49677..cc82f7c3f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.{Row, RandomDataGenerator} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Additional tests for code generation. @@ -93,4 +94,44 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } + + test("test generated safe and unsafe projection") { + val schema = new StructType(Array( + StructField("a", StringType, true), + StructField("b", IntegerType, true), + StructField("c", new StructType(Array( + StructField("aa", StringType, true), + StructField("bb", IntegerType, true) + )), true), + StructField("d", new StructType(Array( + StructField("a", new StructType(Array( + StructField("b", StringType, true), + StructField("", IntegerType, true) + )), true) + )), true) + )) + val row = Row("a", 1, Row("b", 2), Row(Row("c", 3))) + val lit = Literal.create(row, schema) + val internalRow = lit.value.asInstanceOf[InternalRow] + + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow: UnsafeRow = unsafeProj(internalRow) + assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a")) + assert(unsafeRow.getInt(1) === 1) + assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b")) + assert(unsafeRow.getStruct(2, 2).getInt(1) === 2) + assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) === + UTF8String.fromString("c")) + assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3) + + val fromUnsafe = FromUnsafeProjection(schema) + val internalRow2 = fromUnsafe(unsafeRow) + assert(internalRow === internalRow2) + + // update unsafeRow should not affect internalRow2 + unsafeRow.setInt(1, 10) + unsafeRow.getStruct(2, 2).setInt(1, 10) + unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) + assert(internalRow === internalRow2) + } } |