author     Davies Liu <davies@databricks.com>  2015-07-17 01:27:14 -0700
committer  Reynold Xin <rxin@databricks.com>   2015-07-17 01:27:14 -0700
commit     ec8973d1245d4a99edeb7365d7f4b0063ac31ddf (patch)
tree       fbda393df0fdde5f1b088c6882561401f77bdc3f /sql
parent     5a3c1ad087cb645a9496349ca021168e479ffae9 (diff)
[SPARK-9022] [SQL] Generated projections for UnsafeRow
Added two projections, GenerateUnsafeProjection and FromUnsafeProjection, which can be used to convert between UnsafeRow and GenericInternalRow. They reuse the buffer across projections, similar to MutableProjection (but without all of MutableProjection's interface).

cc rxin JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #7437 from davies/unsafe_proj2 and squashes the following commits:

dbf538e [Davies Liu] test with all the expression (only for supported types)
dc737b2 [Davies Liu] address comment
e424520 [Davies Liu] fix scala style
70e231c [Davies Liu] address comments
729138d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_proj2
5a26373 [Davies Liu] unsafe projections
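For context, a minimal usage sketch of the two new projections (Scala; the schema and row values here are illustrative, not part of the patch):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Project a generic row into the UnsafeRow binary format. The projection
// reuses an internal scratch buffer, so the returned UnsafeRow is only
// valid until the next call to apply().
val toUnsafe = UnsafeProjection.create(Seq(IntegerType, StringType))
val unsafeRow = toUnsafe(InternalRow(1, UTF8String.fromString("hello")))

// Convert back to a GenericInternalRow, e.g. to compare against expected values.
val fromUnsafe = FromUnsafeProjection(IntegerType :: StringType :: Nil)
val genericRow = fromUnsafe(unsafeRow)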
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java                         |  27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala                               |   8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala                         |  35
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala                 |  69
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala              |  15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala         |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala   | 125
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala                   |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala                               |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala                               |  17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala               |  34
11 files changed, 266 insertions(+), 72 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index b94601cf6d..d1d81c87bb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -28,13 +28,11 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -52,10 +50,9 @@ final class UnsafeExternalRowSorter {
private long numRowsInserted = 0;
private final StructType schema;
- private final UnsafeRowConverter rowConverter;
+ private final UnsafeProjection unsafeProjection;
private final PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
- private byte[] rowConversionBuffer = new byte[1024 * 8];
public static abstract class PrefixComputer {
abstract long computePrefix(InternalRow row);
@@ -67,7 +64,7 @@ final class UnsafeExternalRowSorter {
PrefixComparator prefixComparator,
PrefixComputer prefixComputer) throws IOException {
this.schema = schema;
- this.rowConverter = new UnsafeRowConverter(schema);
+ this.unsafeProjection = UnsafeProjection.create(schema);
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
@@ -94,18 +91,12 @@ final class UnsafeExternalRowSorter {
@VisibleForTesting
void insertRow(InternalRow row) throws IOException {
- final int sizeRequirement = rowConverter.getSizeRequirement(row);
- if (sizeRequirement > rowConversionBuffer.length) {
- rowConversionBuffer = new byte[sizeRequirement];
- }
- final int bytesWritten = rowConverter.writeRow(
- row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
- assert (bytesWritten == sizeRequirement);
+ UnsafeRow unsafeRow = unsafeProjection.apply(row);
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
- rowConversionBuffer,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- sizeRequirement,
+ unsafeRow.getBaseObject(),
+ unsafeRow.getBaseOffset(),
+ unsafeRow.getSizeInBytes(),
prefix
);
numRowsInserted++;
@@ -186,7 +177,7 @@ final class UnsafeExternalRowSorter {
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
for (StructField field : schema.fields()) {
- if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
+ if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
return false;
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 65ae87fe6d..692b9fddbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -424,20 +424,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
- s"${ctx.stringType}.fromBytes($c)")
+ s"UTF8String.fromBytes($c)")
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
- s"""${ctx.stringType}.fromString(
+ s"""UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
case (TimestampType, StringType) =>
defineCodeGen(ctx, ev, c =>
- s"""${ctx.stringType}.fromString(
+ s"""UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
case (_, StringType) =>
- defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
+ defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))")
case (StringType, IntervalType) =>
defineCodeGen(ctx, ev, c =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index bf47a6c75b..24b01ea551 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
+import org.apache.spark.sql.types.{StructType, DataType}
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -74,6 +76,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
}
/**
+ * A projection that returns UnsafeRow.
+ */
+abstract class UnsafeProjection extends Projection {
+ override def apply(row: InternalRow): UnsafeRow
+}
+
+object UnsafeProjection {
+ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
+
+ def create(fields: Seq[DataType]): UnsafeProjection = {
+ val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+ GenerateUnsafeProjection.generate(exprs)
+ }
+}
+
+/**
+ * A projection that turns an UnsafeRow back into a GenericInternalRow
+ */
+case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+
+ private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
+ new BoundReference(idx, dt, true)
+ }
+
+ @transient private[this] lazy val generatedProj =
+ GenerateMutableProjection.generate(expressions)()
+
+ override def apply(input: InternalRow): InternalRow = {
+ generatedProj(input)
+ }
+}
+
+/**
* A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
* be instantiated once per thread and reused.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 6af5e6200e..885ab091fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -147,77 +147,73 @@ private object UnsafeColumnWriter {
case t => ObjectUnsafeColumnWriter
}
}
+
+ /**
+ * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool).
+ */
+ def canEmbed(dataType: DataType): Boolean = {
+ forType(dataType) != ObjectUnsafeColumnWriter
+ }
}
// ------------------------------------------------------------------------------------------------
-private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter
-private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter
-private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter
-private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter
-private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
-private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
-private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
-private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
-private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
-private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
-private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter
private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
// Primitives don't write to the variable-length region:
def getSize(sourceRow: InternalRow, column: Int): Int = 0
}
-private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setNullAt(column)
0
}
}
-private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setBoolean(column, source.getBoolean(column))
0
}
}
-private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setByte(column, source.getByte(column))
0
}
}
-private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setShort(column, source.getShort(column))
0
}
}
-private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setInt(column, source.getInt(column))
0
}
}
-private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setLong(column, source.getLong(column))
0
}
}
-private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setFloat(column, source.getFloat(column))
0
}
}
-private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setDouble(column, source.getDouble(column))
0
@@ -226,18 +222,21 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
- def getBytes(source: InternalRow, column: Int): Array[Byte]
+ protected[this] def isString: Boolean
+ protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte]
- def getSize(source: InternalRow, column: Int): Int = {
+ override def getSize(source: InternalRow, column: Int): Int = {
val numBytes = getBytes(source, column).length
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
- protected[this] def isString: Boolean
-
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
- val offset = target.getBaseOffset + cursor
val bytes = getBytes(source, column)
+ write(target, bytes, column, cursor)
+ }
+
+ def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = {
+ val offset = target.getBaseOffset + cursor
val numBytes = bytes.length
if ((numBytes & 0x07) > 0) {
// zero-out the padding bytes
@@ -256,22 +255,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
}
}
-private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter {
protected[this] def isString: Boolean = true
def getBytes(source: InternalRow, column: Int): Array[Byte] = {
source.getAs[UTF8String](column).getBytes
}
+ // TODO(davies): refactor this
+ // specialized for codegen
+ def getSize(value: UTF8String): Int =
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes())
+ def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = {
+ write(target, value.getBytes, column, cursor)
+ }
}
-private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
- protected[this] def isString: Boolean = false
- def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter {
+ protected[this] override def isString: Boolean = false
+ override def getBytes(source: InternalRow, column: Int): Array[Byte] = {
source.getAs[Array[Byte]](column)
}
+ // specialized for codegen
+ def getSize(value: Array[Byte]): Int =
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length)
}
-private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter {
- def getSize(sourceRow: InternalRow, column: Int): Int = 0
+private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter {
+ override def getSize(sourceRow: InternalRow, column: Int): Int = 0
override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
val obj = source.get(column)
val idx = target.getPool.put(obj)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 328d635de8..45dc146488 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -24,6 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.ClassBodyEvaluator
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -68,9 +69,6 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initialValue))
}
- val stringType: String = classOf[UTF8String].getName
- val decimalType: String = classOf[Decimal].getName
-
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -136,9 +134,9 @@ class CodeGenContext {
case LongType | TimestampType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
- case dt: DecimalType => decimalType
+ case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
- case StringType => stringType
+ case StringType => "UTF8String"
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
@@ -262,7 +260,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
- evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow"))
+ evaluator.setDefaultImports(Array(
+ classOf[InternalRow].getName,
+ classOf[UnsafeRow].getName,
+ classOf[UTF8String].getName,
+ classOf[Decimal].getName
+ ))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
evaluator.cook(code)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 3e5ca308dc..8f9fcbf810 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.types._
/**
* Java can not access Projection (in package object)
*/
-abstract class BaseProject extends Projection {}
+abstract class BaseProjection extends Projection {}
/**
* Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input
@@ -160,7 +160,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
return new SpecificProjection(expr);
}
- class SpecificProjection extends ${classOf[BaseProject].getName} {
+ class SpecificProjection extends ${classOf[BaseProjection].getName} {
private $exprType[] expressions = null;
$mutableStates
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
new file mode 100644
index 0000000000..a81d545a8e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
+
+
+/**
+ * Generates a [[Projection]] that returns an [[UnsafeRow]].
+ *
+ * It generates code for all the expressions, computes the total length of all the columns
+ * (accessible via variables), and then copies the data into a scratch buffer in the form
+ * of an UnsafeRow (the scratch buffer grows as needed).
+ *
+ * Note: The returned UnsafeRow points to a scratch buffer inside the projection.
+ */
+object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer.execute)
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(expressions: Seq[Expression]): UnsafeProjection = {
+ val ctx = newCodeGenContext()
+ val exprs = expressions.map(_.gen(ctx))
+ val allExprs = exprs.map(_.code).mkString("\n")
+ val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+ val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter"
+ val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter"
+ val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
+ e.dataType match {
+ case StringType =>
+ s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))"
+ case BinaryType =>
+ s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))"
+ case _ => ""
+ }
+ }.mkString("")
+
+ val writers = expressions.zipWithIndex.map { case (e, i) =>
+ val update = e.dataType match {
+ case dt if ctx.isPrimitiveType(dt) =>
+ s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}"
+ case StringType =>
+ s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+ case BinaryType =>
+ s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+ case NullType => ""
+ case _ =>
+ throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
+ }
+ s"""if (${exprs(i).isNull}) {
+ target.setNullAt($i);
+ } else {
+ $update;
+ }"""
+ }.mkString("\n ")
+
+ val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+ s"private $javaType $variableName = $initialValue;"
+ }.mkString("\n ")
+
+ val code = s"""
+ private $exprType[] expressions;
+
+ public Object generate($exprType[] expr) {
+ this.expressions = expr;
+ return new SpecificProjection();
+ }
+
+ class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+
+ private UnsafeRow target = new UnsafeRow();
+ private byte[] buffer = new byte[64];
+
+ $mutableStates
+
+ public SpecificProjection() {}
+
+ // Scala.Function1 needs this
+ public Object apply(Object row) {
+ return apply((InternalRow) row);
+ }
+
+ public UnsafeRow apply(InternalRow i) {
+ ${allExprs}
+
+ // additionalSize starts with '+'
+ int numBytes = $fixedSize $additionalSize;
+ if (numBytes > buffer.length) {
+ buffer = new byte[numBytes];
+ }
+ target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+ ${expressions.size}, numBytes, null);
+ int cursor = $fixedSize;
+ $writers
+ return target;
+ }
+ }
+ """
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+
+ val c = compile(code)
+ c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
+ }
+}
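As a worked example of the size computation in the generated apply() above, consider a row with one int column and one string column (a sketch mirroring the generated arithmetic, not code from the patch):

// Fixed region: one 8-byte word per field, plus the null bitset
// (one word per 64 fields, as in UnsafeRow.calculateBitSetWidthInBytes).
val numFields = 2
val bitSetWidth = ((numFields + 63) / 64) * 8       // 8 bytes for up to 64 fields
val fixedSize = 8 * numFields + bitSetWidth         // 16 + 8 = 24 bytes

// Variable-length region: string/binary bytes rounded up to a whole word,
// as in ByteArrayMethods.roundNumberOfBytesToNearestWord.
def roundToWord(numBytes: Int): Int = (numBytes + 7) / 8 * 8
val stringSize = roundToWord("hello".getBytes("UTF-8").length)  // 5 -> 8 bytes
val numBytes = fixedSize + stringSize               // 24 + 8 = 32-byte buffer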
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 2fa74b4ffc..b9d4736a65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -54,7 +54,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
- ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale);
+ ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale);
${ev.isNull} = ${ev.primitive} == null;
"""
})
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index a7ad452ef4..84b289c4d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -263,7 +263,7 @@ case class Bin(child: Expression)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c) =>
- s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))")
+ s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a269ec4a1e..8d8d66ddeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
-import java.security.MessageDigest
-import java.security.NoSuchAlgorithmException
+import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.zip.CRC32
import org.apache.commons.codec.digest.DigestUtils
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -42,7 +41,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
- s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+ s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
}
}
@@ -93,19 +92,19 @@ case class Sha2(left: Expression, right: Expression)
try {
java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
md.update($eval1);
- ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+ ${ev.primitive} = UTF8String.fromBytes(md.digest());
} catch (java.security.NoSuchAlgorithmException e) {
${ev.isNull} = true;
}
} else if ($eval2 == 256 || $eval2 == 0) {
${ev.primitive} =
- ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1));
+ UTF8String.fromString($digestUtils.sha256Hex($eval1));
} else if ($eval2 == 384) {
${ev.primitive} =
- ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1));
+ UTF8String.fromString($digestUtils.sha384Hex($eval1));
} else if ($eval2 == 512) {
${ev.primitive} =
- ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1));
+ UTF8String.fromString($digestUtils.sha512Hex($eval1));
} else {
${ev.isNull} = true;
}
@@ -129,7 +128,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
- s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+ s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 43392df4be..c43486b3dd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -23,7 +23,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
@@ -43,6 +43,9 @@ trait ExpressionEvalHelper {
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
+ if (UnsafeColumnWriter.canEmbed(expression.dataType)) {
+ checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
+ }
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
}
@@ -142,6 +145,35 @@ trait ExpressionEvalHelper {
}
}
+ protected def checkEvalutionWithUnsafeProjection(
+ expression: Expression,
+ expected: Any,
+ inputRow: InternalRow = EmptyRow): Unit = {
+ val ctx = GenerateUnsafeProjection.newCodeGenContext()
+ lazy val evaluated = expression.gen(ctx)
+
+ val plan = try {
+ GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
+ } catch {
+ case e: Throwable =>
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code}
+ |$e
+ """.stripMargin)
+ }
+
+ val unsafeRow = plan(inputRow)
+ // UnsafeRow cannot be compared with GenericInternalRow directly
+ val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow)
+ val expectedRow = InternalRow(expected)
+ if (actual != expectedRow) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+
protected def checkEvaluationWithOptimization(
expression: Expression,
expected: Any,