diff options
Diffstat (limited to 'sql')
7 files changed, 258 insertions, 78 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index d26b1b187c..af61e2011f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -21,24 +21,40 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer when construct unsafe rows. + * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and + * automatically re-point the unsafe row to it. + * + * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written + * to the data buffer directly and no extra copy is needed. There should be only one instance of + * this class per writing program, so that the memory segment/data buffer can be reused. Note that + * for each incoming record, we should call `reset` of BufferHolder instance before write the record + * and reuse the data buffer. + * + * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update + * the size of the result row, after writing a record to the buffer. However, we can skip this step + * if the fields of row are all fixed-length, as the size of result row is also fixed. */ public class BufferHolder { public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; + private final UnsafeRow row; + private final int fixedSize; - public BufferHolder() { - this(64); + public BufferHolder(UnsafeRow row) { + this(row, 64); } - public BufferHolder(int size) { - buffer = new byte[size]; + public BufferHolder(UnsafeRow row, int initialSize) { + this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); + this.buffer = new byte[fixedSize + initialSize]; + this.row = row; + this.row.pointTo(buffer, buffer.length); } /** - * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize, UnsafeRow row) { + public void grow(int neededSize) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -50,22 +66,12 @@ public class BufferHolder { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; - if (row != null) { - row.pointTo(buffer, length * 2); - } + row.pointTo(buffer, buffer.length); } } - public void grow(int neededSize) { - grow(neededSize, null); - } - public void reset() { - cursor = Platform.BYTE_ARRAY_OFFSET; - } - public void resetTo(int offset) { - assert(offset <= buffer.length); - cursor = Platform.BYTE_ARRAY_OFFSET + offset; + cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } public int totalSize() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index e227c0dec9..4776617043 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -26,38 +26,56 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** - * A helper class to write data into global row buffer using `UnsafeRow` format, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + * A helper class to write data into global row buffer using `UnsafeRow` format. + * + * It will remember the offset of row buffer which it starts to write, and move the cursor of row + * buffer while writing. If new data(can be the input record if this is the outermost writer, or + * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be + * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * `startingOffset` and clear out null bits. + * + * Note that if this is the outermost writer, which means we will always write from the very + * beginning of the global row buffer, we don't need to update `startingOffset` and can just call + * `zeroOutNullBytes` before writing new data. */ public class UnsafeRowWriter { - private BufferHolder holder; + private final BufferHolder holder; // The offset of the global buffer where we start to write this row. private int startingOffset; - private int nullBitsSize; - private UnsafeRow row; + private final int nullBitsSize; + private final int fixedSize; - public void initialize(BufferHolder holder, int numFields) { + public UnsafeRowWriter(BufferHolder holder, int numFields) { this.holder = holder; - this.startingOffset = holder.cursor; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); + this.fixedSize = nullBitsSize + 8 * numFields; + this.startingOffset = holder.cursor; + } + + /** + * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null + * bits. This should be called before we write a new nested struct to the row buffer. + */ + public void reset() { + this.startingOffset = holder.cursor; // grow the global buffer to make sure it has enough space to write fixed-length data. - final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize, row); + holder.grow(fixedSize); holder.cursor += fixedSize; - // zero-out the null bits region + zeroOutNullBytes(); + } + + /** + * Clears out null bits. This should be called before we write a new row to row buffer. + */ + public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { Platform.putLong(holder.buffer, startingOffset + i, 0L); } } - public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { - initialize(holder, numFields); - this.row = row; - } - private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); @@ -98,7 +116,7 @@ public class UnsafeRowWriter { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes, row); + holder.grow(paddingBytes); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -161,7 +179,7 @@ public class UnsafeRowWriter { } } else { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -193,7 +211,7 @@ public class UnsafeRowWriter { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -214,7 +232,7 @@ public class UnsafeRowWriter { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -230,7 +248,7 @@ public class UnsafeRowWriter { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 72bf39a039..6aa9cbf08b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - private val rowWriterClass = classOf[UnsafeRowWriter].getName - private val arrayWriterClass = classOf[UnsafeArrayWriter].getName - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, @@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String): String = { + bufferHolder: String, + isTopLevel: Boolean = false): String = { + val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") + ctx.addMutableState(rowWriterClass, rowWriter, + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + + val resetWriter = if (isTopLevel) { + // For top level row writer, it always writes to the beginning of the global buffer holder, + // which means its fixed-size region always in the same position, so we don't need to call + // `reset` to set up its fixed-size region every time. + if (inputs.map(_.isNull).forall(_ == "false")) { + // If all fields are not nullable, which means the null bits never changes, then we don't + // need to clear it out every time. + "" + } else { + s"$rowWriter.zeroOutNullBytes();" + } + } else { + s"$rowWriter.reset();" + } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => @@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ - case _ if ctx.isPrimitiveType(dt) => - s""" - $rowWriter.write($index, ${input.value}); - """ - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } s""" - $rowWriter.initialize($bufferHolder, ${inputs.length}); + $resetWriter ${ctx.splitExpressions(row, writeFields)} """.trim } @@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, elementType: DataType, bufferHolder: String): String = { + val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"this.$arrayWriter = new $arrayWriterClass();") @@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) + val numVarLenFields = exprTypes.count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + val result = ctx.freshName("result") ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") - val bufferHolder = ctx.freshName("bufferHolder") + + val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + + val resetBufferHolder = if (numVarLenFields == 0) { + "" + } else { + s"$holder.reset();" + } + val updateRowSize = if (numVarLenFields == 0) { + "" + } else { + s"$result.setTotalSize($holder.totalSize());" + } // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val writeExpressions = + writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val code = s""" - $bufferHolder.reset(); + $resetBufferHolder $evalSubexpr - ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - - $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); + $writeExpressions + $updateRowSize """ ExprCode(code, "false", result) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala new file mode 100644 index 0000000000..a6d9040938 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + */ +object UnsafeProjectionBenchmark { + + def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray + } + + def main(args: Array[String]) { + val iters = 1024 * 16 + val numRows = 1024 * 16 + + val benchmark = new Benchmark("unsafe projection", iters * numRows) + + + val schema1 = new StructType().add("l", LongType, false) + val attrs1 = schema1.toAttributes + val rows1 = generateRows(schema1, numRows) + val projection1 = UnsafeProjection.create(attrs1, attrs1) + + benchmark.addCase("single long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection1(rows1(i)).getLong(0) + i += 1 + } + } + } + + val schema2 = new StructType().add("l", LongType, true) + val attrs2 = schema2.toAttributes + val rows2 = generateRows(schema2, numRows) + val projection2 = UnsafeProjection.create(attrs2, attrs2) + + benchmark.addCase("single nullable long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection2(rows2(i)).getLong(0) + i += 1 + } + } + } + + + val schema3 = new StructType() + .add("boolean", BooleanType, false) + .add("byte", ByteType, false) + .add("short", ShortType, false) + .add("int", IntegerType, false) + .add("long", LongType, false) + .add("float", FloatType, false) + .add("double", DoubleType, false) + val attrs3 = schema3.toAttributes + val rows3 = generateRows(schema3, numRows) + val projection3 = UnsafeProjection.create(attrs3, attrs3) + + benchmark.addCase("7 primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection3(rows3(i)).getLong(0) + i += 1 + } + } + } + + + val schema4 = new StructType() + .add("boolean", BooleanType, true) + .add("byte", ByteType, true) + .add("short", ShortType, true) + .add("int", IntegerType, true) + .add("long", LongType, true) + .add("float", FloatType, true) + .add("double", DoubleType, true) + val attrs4 = schema4.toAttributes + val rows4 = generateRows(schema4, numRows) + val projection4 = UnsafeProjection.create(attrs4, attrs4) + + benchmark.addCase("7 nullable primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection4(rows4(i)).getLong(0) + i += 1 + } + } + } + + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + single long 1533.34 175.07 1.00 X + single nullable long 2306.73 116.37 0.66 X + primitive types 8403.93 31.94 0.18 X + nullable primitive types 12448.39 21.56 0.12 X + */ + benchmark.run() + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 80805f15a8..17adfec321 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -74,11 +74,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas private boolean containsVarLenFields; /** - * The number of bytes in the fixed length portion of the row. - */ - private int fixedSizeBytes; - - /** * For each request column, the reader to read this column. * columnsReaders[i] populated the UnsafeRow's attribute at i. */ @@ -266,19 +261,13 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas /** * Initialize rows and rowWriters. These objects are reused across all rows in the relation. */ - int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); - rowByteSize += 8 * requestedSchema.getFieldCount(); - fixedSizeBytes = rowByteSize; - rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; containsVarLenFields = numVarLenFields > 0; rowWriters = new UnsafeRowWriter[rows.length]; for (int i = 0; i < rows.length; ++i) { rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); - rowWriters[i] = new UnsafeRowWriter(); - BufferHolder holder = new BufferHolder(rowByteSize); - rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); - rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length); + BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE); + rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount()); } } @@ -295,7 +284,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas if (containsVarLenFields) { for (int i = 0; i < rowWriters.length; ++i) { - rowWriters[i].holder().resetTo(fixedSizeBytes); + rowWriters[i].holder().reset(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 72eb1f6cf0..738b9a35d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -132,8 +132,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -181,9 +181,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; bufferHolder.reset(); - rowWriter.initialize(bufferHolder, $numFields); + rowWriter.zeroOutNullBytes(); ${extractors.mkString("\n")} - unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()); + unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index bd2d17c018..430257f60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -98,16 +98,15 @@ private[sql] class TextRelation( sqlContext.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) .mapPartitions { iter => - val bufferHolder = new BufferHolder - val unsafeRowWriter = new UnsafeRowWriter val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) iter.map { case (_, line) => // Writes to an UnsafeRow directly bufferHolder.reset() - unsafeRowWriter.initialize(bufferHolder, 1) unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()) + unsafeRow.setTotalSize(bufferHolder.totalSize()) unsafeRow } } |