-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java    4
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java    88
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java    16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala    4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala    4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala    4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java    10
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java    24
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java    32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala    6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala    3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala    4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala    4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala    5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala    4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala    11
16 files changed, 86 insertions(+), 137 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 3513960b41..3d80df2271 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -270,8 +270,8 @@ public class UnsafeArrayData extends ArrayData {
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
- final UnsafeRow row = new UnsafeRow();
- row.pointTo(baseObject, baseOffset + offset, numFields, size);
+ final UnsafeRow row = new UnsafeRow(numFields);
+ row.pointTo(baseObject, baseOffset + offset, size);
return row;
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index b6979d0c82..7492b88c47 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,11 +17,7 @@
package org.apache.spark.sql.catalyst.expressions;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.io.OutputStream;
+import java.io.*;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
@@ -30,26 +26,12 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
-import org.apache.spark.sql.types.ArrayType;
-import org.apache.spark.sql.types.BinaryType;
-import org.apache.spark.sql.types.BooleanType;
-import org.apache.spark.sql.types.ByteType;
-import org.apache.spark.sql.types.CalendarIntervalType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DateType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.sql.types.DoubleType;
-import org.apache.spark.sql.types.FloatType;
-import org.apache.spark.sql.types.IntegerType;
-import org.apache.spark.sql.types.LongType;
-import org.apache.spark.sql.types.MapType;
-import org.apache.spark.sql.types.NullType;
-import org.apache.spark.sql.types.ShortType;
-import org.apache.spark.sql.types.StringType;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.types.TimestampType;
-import org.apache.spark.sql.types.UserDefinedType;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
@@ -57,23 +39,9 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import static org.apache.spark.sql.types.DataTypes.BooleanType;
-import static org.apache.spark.sql.types.DataTypes.ByteType;
-import static org.apache.spark.sql.types.DataTypes.DateType;
-import static org.apache.spark.sql.types.DataTypes.DoubleType;
-import static org.apache.spark.sql.types.DataTypes.FloatType;
-import static org.apache.spark.sql.types.DataTypes.IntegerType;
-import static org.apache.spark.sql.types.DataTypes.LongType;
-import static org.apache.spark.sql.types.DataTypes.NullType;
-import static org.apache.spark.sql.types.DataTypes.ShortType;
-import static org.apache.spark.sql.types.DataTypes.TimestampType;
+import static org.apache.spark.sql.types.DataTypes.*;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.KryoSerializable;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
-
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -167,8 +135,16 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
/**
* Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
* since the value returned by this constructor is equivalent to a null pointer.
+ *
+ * @param numFields the number of fields in this row
*/
- public UnsafeRow() { }
+ public UnsafeRow(int numFields) {
+ this.numFields = numFields;
+ this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+ }
+
+ // for serializer
+ public UnsafeRow() {}
public Object getBaseObject() { return baseObject; }
public long getBaseOffset() { return baseOffset; }
@@ -182,15 +158,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
*
* @param baseObject the base object
* @param baseOffset the offset within the base object
- * @param numFields the number of fields in this row
* @param sizeInBytes the size of this row's backing data, in bytes
*/
- public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) {
+ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
assert numFields >= 0 : "numFields (" + numFields + ") should >= 0";
- this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
- this.numFields = numFields;
this.sizeInBytes = sizeInBytes;
}
@@ -198,23 +171,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
* Update this UnsafeRow to point to the underlying byte array.
*
* @param buf byte array to point to
- * @param numFields the number of fields in this row
- * @param sizeInBytes the number of bytes valid in the byte array
- */
- public void pointTo(byte[] buf, int numFields, int sizeInBytes) {
- pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
- }
-
- /**
- * Updates this UnsafeRow preserving the number of fields.
- * @param buf byte array to point to
* @param sizeInBytes the number of bytes valid in the byte array
*/
public void pointTo(byte[] buf, int sizeInBytes) {
- pointTo(buf, numFields, sizeInBytes);
+ pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
}
-
public void setNotNullAt(int i) {
assertIndexIsValid(i);
BitSetMethods.unset(baseObject, baseOffset, i);
@@ -489,8 +451,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
- final UnsafeRow row = new UnsafeRow();
- row.pointTo(baseObject, baseOffset + offset, numFields, size);
+ final UnsafeRow row = new UnsafeRow(numFields);
+ row.pointTo(baseObject, baseOffset + offset, size);
return row;
}
}
@@ -529,7 +491,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
*/
@Override
public UnsafeRow copy() {
- UnsafeRow rowCopy = new UnsafeRow();
+ UnsafeRow rowCopy = new UnsafeRow(numFields);
final byte[] rowDataCopy = new byte[sizeInBytes];
Platform.copyMemory(
baseObject,
@@ -538,7 +500,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
Platform.BYTE_ARRAY_OFFSET,
sizeInBytes
);
- rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
+ rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
return rowCopy;
}
@@ -547,8 +509,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
* The returned row is invalid until we call copyFrom on it.
*/
public static UnsafeRow createFromByteArray(int numBytes, int numFields) {
- final UnsafeRow row = new UnsafeRow();
- row.pointTo(new byte[numBytes], numFields, numBytes);
+ final UnsafeRow row = new UnsafeRow(numFields);
+ row.pointTo(new byte[numBytes], numBytes);
return row;
}
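
For reference, a minimal caller-side sketch of the API change in this file (illustrative code, not part of the patch; `bytes`, `numFields`, and `sizeInBytes` are assumed stand-ins):

    // Before: the field count was passed on every pointTo() call, and the
    // bitset width was recomputed on each call.
    UnsafeRow oldStyle = new UnsafeRow();
    oldStyle.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);

    // After: the field count is fixed at construction, so
    // calculateBitSetWidthInBytes(numFields) runs once per row object rather
    // than once per pointTo(). The no-arg constructor is kept only "for
    // serializer": Externalizable (and Kryo's default instantiation) needs
    // an empty instance to exist before its state is read back in.
    UnsafeRow newStyle = new UnsafeRow(numFields);
    newStyle.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
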
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 352002b349..27ae62f121 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -26,10 +26,9 @@ import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
-import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
@@ -123,7 +122,7 @@ final class UnsafeExternalRowSorter {
return new AbstractScalaRowIterator<UnsafeRow>() {
private final int numFields = schema.length();
- private UnsafeRow row = new UnsafeRow();
+ private UnsafeRow row = new UnsafeRow(numFields);
@Override
public boolean hasNext() {
@@ -137,7 +136,6 @@ final class UnsafeExternalRowSorter {
row.pointTo(
sortedIterator.getBaseObject(),
sortedIterator.getBaseOffset(),
- numFields,
sortedIterator.getRecordLength());
if (!hasNext()) {
UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
@@ -173,19 +171,21 @@ final class UnsafeExternalRowSorter {
private static final class RowComparator extends RecordComparator {
private final Ordering<InternalRow> ordering;
private final int numFields;
- private final UnsafeRow row1 = new UnsafeRow();
- private final UnsafeRow row2 = new UnsafeRow();
+ private final UnsafeRow row1;
+ private final UnsafeRow row2;
public RowComparator(Ordering<InternalRow> ordering, int numFields) {
this.numFields = numFields;
+ this.row1 = new UnsafeRow(numFields);
+ this.row2 = new UnsafeRow(numFields);
this.ordering = ordering;
}
@Override
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
// TODO: Why are the sizes -1?
- row1.pointTo(baseObj1, baseOff1, numFields, -1);
- row2.pointTo(baseObj2, baseOff2, numFields, -1);
+ row1.pointTo(baseObj1, baseOff1, -1);
+ row2.pointTo(baseObj2, baseOff2, -1);
return ordering.compare(row1, row2);
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index c1defe12b0..d0e031f279 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -289,7 +289,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val exprTypes = expressions.map(_.dataType)
val result = ctx.freshName("result")
- ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+ ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
val bufferHolder = ctx.freshName("bufferHolder")
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
@@ -303,7 +303,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$subexprReset
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
- $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+ $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
"""
GeneratedExpressionCode(code, "false", result)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index da602d9b4b..c9ff357bf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -165,7 +165,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|
|class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} {
| private byte[] buf = new byte[64];
- | private UnsafeRow out = new UnsafeRow();
+ | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size});
|
| public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
| // row1: ${schema1.size} fields, $bitset1Words words in bitset
@@ -188,7 +188,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| $copyVariableLengthRow2
| $updateOffset
|
- | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction);
+ | out.pointTo(buf, sizeInBytes - $sizeReduction);
|
| return out;
| }
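
For concreteness, a hypothetical expansion of the template above for a 2-field schema1 joined with a 3-field schema2 (illustrative only; the actual generated class uses the fully-qualified UnsafeRowJoiner name, and the copy code and size constants vary per schema pair):

    class SpecificUnsafeRowJoiner extends UnsafeRowJoiner {
      private byte[] buf = new byte[64];
      private UnsafeRow out = new UnsafeRow(5);  // schema1.size + schema2.size = 2 + 3

      public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
        // ... bitset and variable-length data copying elided ...
        // sizeReduction is baked in as a compile-time constant here
        out.pointTo(buf, sizeInBytes - 8);
        return out;
      }
    }
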
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 796d60032e..f8342214d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -90,13 +90,13 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
}
private def createUnsafeRow(numFields: Int): UnsafeRow = {
- val row = new UnsafeRow
+ val row = new UnsafeRow(numFields)
val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
// Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle.
// This way we can test the joiner when the input UnsafeRows are not the entire arrays.
val offset = numFields * 8
val buf = new Array[Byte](sizeInBytes + offset)
- row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
+ row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes)
row
}
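
As a quick sanity check of the size formula in createUnsafeRow above (a worked example, not part of the patch):

    // For numFields = 100: fixed-width region = 100 * 8 = 800 bytes;
    // bitset = ((100 + 63) / 64) * 8 = 2 * 8 = 16 bytes (integer division);
    // so sizeInBytes = 816, and the row is deliberately pointed at an
    // offset of numFields * 8 = 800 bytes into a larger buffer.
    int numFields = 100;
    int sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8;  // 816
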
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index a2f99d566d..6bf9d7bd03 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -61,7 +61,7 @@ public final class UnsafeFixedWidthAggregationMap {
/**
* Re-used pointer to the current aggregation buffer
*/
- private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+ private final UnsafeRow currentAggregationBuffer;
private final boolean enablePerfMetrics;
@@ -98,6 +98,7 @@ public final class UnsafeFixedWidthAggregationMap {
long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
+ this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
this.map =
@@ -147,7 +148,6 @@ public final class UnsafeFixedWidthAggregationMap {
currentAggregationBuffer.pointTo(
address.getBaseObject(),
address.getBaseOffset(),
- aggregationBufferSchema.length(),
loc.getValueLength()
);
return currentAggregationBuffer;
@@ -165,8 +165,8 @@ public final class UnsafeFixedWidthAggregationMap {
private final BytesToBytesMap.MapIterator mapLocationIterator =
map.destructiveIterator();
- private final UnsafeRow key = new UnsafeRow();
- private final UnsafeRow value = new UnsafeRow();
+ private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());
+ private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());
@Override
public boolean next() {
@@ -177,13 +177,11 @@ public final class UnsafeFixedWidthAggregationMap {
key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
- groupingKeySchema.length(),
loc.getKeyLength()
);
value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
- aggregationBufferSchema.length(),
loc.getValueLength()
);
return true;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 8c9b9c85e3..0da26bf376 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -94,7 +94,7 @@ public final class UnsafeKVExternalSorter {
// The only new memory we are allocating is the pointer/prefix array.
BytesToBytesMap.MapIterator iter = map.iterator();
final int numKeyFields = keySchema.size();
- UnsafeRow row = new UnsafeRow();
+ UnsafeRow row = new UnsafeRow(numKeyFields);
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
final Object baseObject = loc.getKeyAddress().getBaseObject();
@@ -107,7 +107,7 @@ public final class UnsafeKVExternalSorter {
long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
// Compute prefix
- row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+ row.pointTo(baseObject, baseOffset, loc.getKeyLength());
final long prefix = prefixComputer.computePrefix(row);
inMemSorter.insertRecord(address, prefix);
@@ -194,12 +194,14 @@ public final class UnsafeKVExternalSorter {
private static final class KVComparator extends RecordComparator {
private final BaseOrdering ordering;
- private final UnsafeRow row1 = new UnsafeRow();
- private final UnsafeRow row2 = new UnsafeRow();
+ private final UnsafeRow row1;
+ private final UnsafeRow row2;
private final int numKeyFields;
public KVComparator(BaseOrdering ordering, int numKeyFields) {
this.numKeyFields = numKeyFields;
+ this.row1 = new UnsafeRow(numKeyFields);
+ this.row2 = new UnsafeRow(numKeyFields);
this.ordering = ordering;
}
@@ -207,17 +209,15 @@ public final class UnsafeKVExternalSorter {
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
// Note that since ordering doesn't need the total length of the record, we just pass -1
// into the row.
- row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
- row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+ row1.pointTo(baseObj1, baseOff1 + 4, -1);
+ row2.pointTo(baseObj2, baseOff2 + 4, -1);
return ordering.compare(row1, row2);
}
}
public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
- private UnsafeRow key = new UnsafeRow();
- private UnsafeRow value = new UnsafeRow();
- private final int numKeyFields = keySchema.size();
- private final int numValueFields = valueSchema.size();
+ private UnsafeRow key = new UnsafeRow(keySchema.size());
+ private UnsafeRow value = new UnsafeRow(valueSchema.size());
private final UnsafeSorterIterator underlying;
private KVSorterIterator(UnsafeSorterIterator underlying) {
@@ -237,8 +237,8 @@ public final class UnsafeKVExternalSorter {
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = Platform.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
- key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
- value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+ key.pointTo(baseObj, recordOffset + 4, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
return true;
} else {
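
The record layout decoded by KVSorterIterator above, annotated with hypothetical byte counts (the offsets come from the code; the concrete numbers are illustrative):

    // Layout: [ keyLen : 4 bytes ][ key : keyLen bytes ][ value : recordLen - keyLen - 4 bytes ]
    int keyLen = Platform.getInt(baseObj, recordOffset);     // e.g. 24
    int valueLen = recordLen - keyLen - 4;                   // e.g. 64 - 24 - 4 = 36
    key.pointTo(baseObj, recordOffset + 4, keyLen);          // key starts after the 4-byte length
    value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
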
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 0cc4566c9c..a6758bddfa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,35 +21,28 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
-import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.types.UTF8String;
-
-import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL;
-import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
-import static org.apache.parquet.column.ValuesType.VALUES;
-
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.Preconditions;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.Dictionary;
import org.apache.parquet.column.Encoding;
-import org.apache.parquet.column.page.DataPage;
-import org.apache.parquet.column.page.DataPageV1;
-import org.apache.parquet.column.page.DataPageV2;
-import org.apache.parquet.column.page.DictionaryPage;
-import org.apache.parquet.column.page.PageReadStore;
-import org.apache.parquet.column.page.PageReader;
+import org.apache.parquet.column.page.*;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import static org.apache.parquet.column.ValuesType.*;
+
/**
* A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs.
*
@@ -181,12 +174,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
rowWriters = new UnsafeRowWriter[rows.length];
for (int i = 0; i < rows.length; ++i) {
- rows[i] = new UnsafeRow();
+ rows[i] = new UnsafeRow(requestedSchema.getFieldCount());
rowWriters[i] = new UnsafeRowWriter();
BufferHolder holder = new BufferHolder(rowByteSize);
rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount());
- rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(),
- holder.buffer.length);
+ rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length);
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 7e981268de..4730647c4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -94,7 +94,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
// 1024 is a default buffer size; this buffer will grow to accommodate larger rows
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
- private[this] var row: UnsafeRow = new UnsafeRow()
+ private[this] var row: UnsafeRow = new UnsafeRow(numFields)
private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
private[this] val EOF: Int = -1
@@ -117,7 +117,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
rowBuffer = new Array[Byte](rowSize)
}
ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
- row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
+ row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize)
rowSize = readSize()
if (rowSize == EOF) { // We are returning the last row in this stream
dIn.close()
@@ -152,7 +152,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
rowBuffer = new Array[Byte](rowSize)
}
ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
- row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
+ row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize)
row.asInstanceOf[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index c9f2329db4..9c908b2877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -574,11 +574,10 @@ private[columnar] case class STRUCT(dataType: StructType)
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + sizeInBytes)
- val unsafeRow = new UnsafeRow
+ val unsafeRow = new UnsafeRow(numOfFields)
unsafeRow.pointTo(
buffer.array(),
Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
- numOfFields,
sizeInBytes)
unsafeRow
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index eaafc96e4d..b208425ffc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -131,7 +131,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
- private UnsafeRow unsafeRow = new UnsafeRow();
+ private UnsafeRow unsafeRow = new UnsafeRow($numFields);
private BufferHolder bufferHolder = new BufferHolder();
private UnsafeRowWriter rowWriter = new UnsafeRowWriter();
private MutableUnsafeRow mutableRow = null;
@@ -183,7 +183,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
bufferHolder.reset();
rowWriter.initialize(bufferHolder, $numFields);
${extractors.mkString("\n")}
- unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize());
+ unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize());
return unsafeRow;
}
}"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 4a1cbe4c38..41fcb11d84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -101,14 +101,14 @@ private[sql] class TextRelation(
.mapPartitions { iter =>
val bufferHolder = new BufferHolder
val unsafeRowWriter = new UnsafeRowWriter
- val unsafeRow = new UnsafeRow
+ val unsafeRow = new UnsafeRow(1)
iter.map { case (_, line) =>
// Writes to an UnsafeRow directly
bufferHolder.reset()
unsafeRowWriter.initialize(bufferHolder, 1)
unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
- unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())
+ unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize())
unsafeRow
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index fa2bc76721..81bfe4e67c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -56,15 +56,14 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
// Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
def createIter(): Iterator[UnsafeRow] = {
val iter = sorter.getIterator
- val unsafeRow = new UnsafeRow
+ val unsafeRow = new UnsafeRow(numFieldsOfRight)
new Iterator[UnsafeRow] {
override def hasNext: Boolean = {
iter.hasNext
}
override def next(): UnsafeRow = {
iter.loadNext()
- unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight,
- iter.getRecordLength)
+ unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
unsafeRow
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8c7099ab5a..c6f56cfaed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -245,8 +245,8 @@ private[joins] final class UnsafeHashedRelation(
val sizeInBytes = Platform.getInt(base, offset + 4)
offset += 8
- val row = new UnsafeRow
- row.pointTo(base, offset, numFields, sizeInBytes)
+ val row = new UnsafeRow(numFields)
+ row.pointTo(base, offset, sizeInBytes)
buffer += row
offset += sizeInBytes
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 00f1526576..a32763db05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -34,8 +34,8 @@ class UnsafeRowSuite extends SparkFunSuite {
test("UnsafeRow Java serialization") {
// serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
val data = new Array[Byte](1024)
- val row = new UnsafeRow
- row.pointTo(data, 1, 16)
+ val row = new UnsafeRow(1)
+ row.pointTo(data, 16)
row.setLong(0, 19285)
val ser = new JavaSerializer(new SparkConf).newInstance()
@@ -47,8 +47,8 @@ class UnsafeRowSuite extends SparkFunSuite {
test("UnsafeRow Kryo serialization") {
// serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
val data = new Array[Byte](1024)
- val row = new UnsafeRow
- row.pointTo(data, 1, 16)
+ val row = new UnsafeRow(1)
+ row.pointTo(data, 16)
row.setLong(0, 19285)
val ser = new KryoSerializer(new SparkConf).newInstance()
@@ -86,11 +86,10 @@ class UnsafeRowSuite extends SparkFunSuite {
offheapRowPage.getBaseOffset,
arrayBackedUnsafeRow.getSizeInBytes
)
- val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+ val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3)
offheapUnsafeRow.pointTo(
offheapRowPage.getBaseObject,
offheapRowPage.getBaseOffset,
- 3, // num fields
arrayBackedUnsafeRow.getSizeInBytes
)
assert(offheapUnsafeRow.getBaseObject === null)