aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala33
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala147
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala45
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala10
7 files changed, 231 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5a19aa8920..1b475b2492 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -363,15 +363,19 @@ public final class UnsafeRow extends MutableRow {
@Override
public UTF8String getUTF8String(int ordinal) {
assertIndexIsValid(ordinal);
- return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
}
@Override
public byte[] getBinary(int ordinal) {
+ assertIndexIsValid(ordinal);
if (isNullAt(ordinal)) {
return null;
} else {
- assertIndexIsValid(ordinal);
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 000be70f17..83129dc12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@@ -114,8 +114,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
- val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
- create(exprs)
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
}
/**
@@ -139,19 +138,27 @@ object UnsafeProjection {
/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
-case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+object FromUnsafeProjection {
- def this(schema: StructType) = this(schema.fields.map(_.dataType))
-
- private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
- new BoundReference(idx, dt, true)
+ /**
+ * Returns a Projection for the given StructType.
+ */
+ def apply(schema: StructType): Projection = {
+ apply(schema.fields.map(_.dataType))
}
- @transient private[this] lazy val generatedProj =
- GenerateMutableProjection.generate(expressions)()
+ /**
+ * Returns a Projection for the given Seq of DataTypes.
+ */
+ def apply(fields: Seq[DataType]): Projection = {
+ create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ }
- override def apply(input: InternalRow): InternalRow = {
- generatedProj(input)
+ /**
+ * Returns a Projection for the given sequence of bound Expressions.
+ */
+ private def create(exprs: Seq[Expression]): Projection = {
+ GenerateSafeProjection.generate(exprs)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fc7cfee989..3177e6b750 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -127,6 +127,8 @@ class CodeGenContext {
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ // The UTF8String may come from an UnsafeRow; otherwise clone() is cheap (it re-uses the bytes)
+ case StringType => s"$row.update($ordinal, $value.clone())"
case _ => s"$row.update($ordinal, $value)"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
new file mode 100644
index 0000000000..f06ffc5449
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.types.{StringType, StructType, DataType}
+
+
+/**
+ * Generates byte code for a [[Projection]] that copies the given [[Expression Expressions]]'
+ * results into a [[MutableRow]], cloning string values and deep-copying nested structs so the
+ * produced row does not point into the input row's storage.
+ */
+object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer.execute)
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ private def genUpdater(
+ ctx: CodeGenContext,
+ setter: String,
+ dataType: DataType,
+ ordinal: Int,
+ value: String): String = {
+ dataType match {
+ case struct: StructType =>
+ val rowTerm = ctx.freshName("row")
+ val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
+ val colTerm = ctx.freshName("col")
+ s"""
+ if ($value.isNullAt($i)) {
+ $rowTerm.setNullAt($i);
+ } else {
+ ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
+ ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
+ }
+ """
+ }.mkString("\n")
+ s"""
+ $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
+ $updates
+ $setter.update($ordinal, $rowTerm.copy());
+ """
+ case _ =>
+ ctx.setColumn(setter, dataType, ordinal, value)
+ }
+ }
+
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val ctx = newCodeGenContext()
+ val projectionCode = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
+ val evaluationCode = e.gen(ctx)
+ evaluationCode.code +
+ s"""
+ if (${evaluationCode.isNull}) {
+ mutableRow.setNullAt($i);
+ } else {
+ ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
+ }
+ """
+ }
+ // collect projections into blocks as function has 64kb codesize limit in JVM
+ val projectionBlocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (projection <- projectionCode) {
+ if (blockBuilder.length > 16 * 1000) {
+ projectionBlocks.append(blockBuilder.toString())
+ blockBuilder.clear()
+ }
+ blockBuilder.append(projection)
+ }
+ projectionBlocks.append(blockBuilder.toString())
+
+ val (projectionFuns, projectionCalls) = {
+ // inline it if we have only one block
+ if (projectionBlocks.length == 1) {
+ ("", projectionBlocks.head)
+ } else {
+ (
+ projectionBlocks.zipWithIndex.map { case (body, i) =>
+ s"""
+ |private void apply$i(InternalRow i) {
+ | $body
+ |}
+ """.stripMargin
+ }.mkString,
+ projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
+ )
+ }
+ }
+
+ val code = s"""
+ public Object generate($exprType[] expr) {
+ return new SpecificSafeProjection(expr);
+ }
+
+ class SpecificSafeProjection extends ${classOf[BaseProjection].getName} {
+
+ private $exprType[] expressions;
+ private $mutableRowType mutableRow;
+ ${declareMutableStates(ctx)}
+
+ public SpecificSafeProjection($exprType[] expr) {
+ expressions = expr;
+ mutableRow = new $genericMutableRowType(${expressions.size});
+ ${initMutableStates(ctx)}
+ }
+
+ $projectionFuns
+
+ public Object apply(Object _i) {
+ InternalRow i = (InternalRow) _i;
+ $projectionCalls
+
+ return mutableRow;
+ }
+ }
+ """
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+
+ val c = compile(code)
+ c.generate(ctx.references.toArray).asInstanceOf[Projection]
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 80c64e5689..56225290cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -971,12 +971,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
if (!r.equals(lastReplacementInUTF8)) {
// replacement string changed
- lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+ lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone()
lastReplacement = lastReplacementInUTF8.toString
}
val m = pattern.matcher(s.toString())
@@ -1022,12 +1022,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
if (!$rep.equals(${termLastReplacementInUTF8})) {
// replacement string changed
- ${termLastReplacementInUTF8} = $rep;
+ ${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
@@ -1061,7 +1061,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
+ lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString())
@@ -1090,7 +1090,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
s"""
if (!$regexp.equals(${termLastRegex})) {
// regex value changed
- ${termLastRegex} = $regexp;
+ ${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
java.util.regex.Matcher m =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index f4fbc49677..cc82f7c3f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import scala.math._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Additional tests for code generation.
@@ -93,4 +94,44 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
+
+ test("test generated safe and unsafe projection") {
+ val schema = new StructType(Array(
+ StructField("a", StringType, true),
+ StructField("b", IntegerType, true),
+ StructField("c", new StructType(Array(
+ StructField("aa", StringType, true),
+ StructField("bb", IntegerType, true)
+ )), true),
+ StructField("d", new StructType(Array(
+ StructField("a", new StructType(Array(
+ StructField("b", StringType, true),
+ StructField("", IntegerType, true)
+ )), true)
+ )), true)
+ ))
+ val row = Row("a", 1, Row("b", 2), Row(Row("c", 3)))
+ val lit = Literal.create(row, schema)
+ val internalRow = lit.value.asInstanceOf[InternalRow]
+
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow: UnsafeRow = unsafeProj(internalRow)
+ assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a"))
+ assert(unsafeRow.getInt(1) === 1)
+ assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b"))
+ assert(unsafeRow.getStruct(2, 2).getInt(1) === 2)
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) ===
+ UTF8String.fromString("c"))
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)
+
+ val fromUnsafe = FromUnsafeProjection(schema)
+ val internalRow2 = fromUnsafe(unsafeRow)
+ assert(internalRow === internalRow2)
+
+ // update unsafeRow should not affect internalRow2
+ unsafeRow.setInt(1, 10)
+ unsafeRow.getStruct(2, 2).setInt(1, 10)
+ unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4)
+ assert(internalRow === internalRow2)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 30f8fe320d..531a8244d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -329,7 +329,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
- row.update(ordinal, value)
+ row.update(ordinal, value.clone())
}
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -337,7 +337,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
- to.update(toOrdinal, from.getUTF8String(fromOrdinal))
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}
@@ -396,7 +396,11 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
- row(ordinal) = value
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}