-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java | 174
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java | 59
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java | 180
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java | 13
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java | 4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java | 39
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java | 4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 86
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala | 33
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala | 22
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala | 3
16 files changed, 549 insertions, 90 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index b5dddb9f11..4576ac2a32 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -37,6 +37,7 @@ import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
@@ -44,7 +45,7 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.parquet.column.ValuesType.*;
@@ -57,7 +58,7 @@ import static org.apache.parquet.column.ValuesType.*;
* TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
* All of these can be handled efficiently and easily with codegen.
*/
-public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<UnsafeRow> {
+public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<InternalRow> {
/**
* Batch of unsafe rows that we assemble and the current index we've returned. Every time this
* batch is used up (batchIdx == numBatched), we populate the batch.
@@ -110,6 +111,9 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* and currently unsupported cases will fail with potentially difficult to diagnose errors.
* This should be only turned on for development to work on this feature.
*
+ * When this is set, the code will branch early on in the RecordReader APIs. There is no shared
+ * code between the path that uses the MR decoders and the vectorized ones.
+ *
* TODOs:
* - Implement all the encodings to support vectorized.
* - Implement v2 page formats (just make sure we create the correct decoders).
@@ -166,15 +170,23 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
if (batchIdx >= numBatched) {
- if (!loadBatch()) return false;
+ if (vectorizedDecode()) {
+ if (!nextBatch()) return false;
+ } else {
+ if (!loadBatch()) return false;
+ }
}
++batchIdx;
return true;
}
@Override
- public UnsafeRow getCurrentValue() throws IOException, InterruptedException {
- return rows[batchIdx - 1];
+ public InternalRow getCurrentValue() throws IOException, InterruptedException {
+ if (vectorizedDecode()) {
+ return columnarBatch.getRow(batchIdx - 1);
+ } else {
+ return rows[batchIdx - 1];
+ }
}
@Override
@@ -202,20 +214,27 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* Advances to the next batch of rows. Returns false if there are no more.
*/
public boolean nextBatch() throws IOException {
- assert(columnarBatch != null);
+ assert(vectorizedDecode());
columnarBatch.reset();
if (rowsReturned >= totalRowCount) return false;
checkEndOfRowGroup();
- int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
+ int num = (int)Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
columnReaders[i].readBatch(num, columnarBatch.column(i));
}
rowsReturned += num;
columnarBatch.setNumRows(num);
+ numBatched = num;
+ batchIdx = 0;
return true;
}
+ /**
+ * Returns true if we are doing a vectorized decode.
+ */
+ private boolean vectorizedDecode() { return columnarBatch != null; }
+
private void initializeInternal() throws IOException {
/**
* Check that the requested schema is supported.
@@ -613,15 +632,27 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
decodeDictionaryIds(rowId, num, column);
} else {
switch (descriptor.getType()) {
+ case BOOLEAN:
+ readBooleanBatch(rowId, num, column);
+ break;
case INT32:
readIntBatch(rowId, num, column);
break;
case INT64:
readLongBatch(rowId, num, column);
break;
+ case FLOAT:
+ readFloatBatch(rowId, num, column);
+ break;
+ case DOUBLE:
+ readDoubleBatch(rowId, num, column);
+ break;
case BINARY:
readBinaryBatch(rowId, num, column);
break;
+ case FIXED_LEN_BYTE_ARRAY:
+ readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength());
+ break;
default:
throw new IOException("Unsupported type: " + descriptor.getType());
}
@@ -645,7 +676,15 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
} else if (column.dataType() == DataTypes.ByteType) {
for (int i = rowId; i < rowId + num; ++i) {
- column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ShortType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
}
} else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
@@ -653,8 +692,36 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
break;
case INT64:
+ if (column.dataType() == DataTypes.LongType ||
+ DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
+ case FLOAT:
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case DOUBLE:
for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case FIXED_LEN_BYTE_ARRAY:
+ if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
+ }
+ } else {
+ throw new NotImplementedException();
}
break;
@@ -691,15 +758,24 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* is guaranteed that num is smaller than the number of values left in the current page.
*/
+ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
+ assert(column.dataType() == DataTypes.BooleanType);
+ defColumn.readBooleans(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ }
+
private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.IntegerType) {
+ if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ defColumn.readIntsAsLongs(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
}
@@ -707,11 +783,33 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
- // TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.LongType) {
+ if (column.dataType() == DataTypes.LongType ||
+ DecimalType.is64BitDecimalType(column.dataType())) {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
+ throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ }
+ }
+
+ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: support implicit cast to double?
+ if (column.dataType() == DataTypes.FloatType) {
+ defColumn.readFloats(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ }
+ }
+
+ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.DoubleType) {
+ defColumn.readDoubles(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
}
}
@@ -727,6 +825,24 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
}
+ private void readFixedLenByteArrayBatch(int rowId, int num,
+ ColumnVector column, int arrayLen) throws IOException {
+ VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ column.putLong(rowId + i,
+ CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+ } else {
+ column.putNull(rowId + i);
+ }
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
private void readPage() throws IOException {
DataPage page = pageReader.readPage();
@@ -763,7 +879,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
"could not read page in col " + descriptor +
" as the dictionary was missing for encoding " + dataEncoding);
}
- if (columnarBatch != null && dataEncoding == Encoding.PLAIN_DICTIONARY) {
+ if (vectorizedDecode()) {
+ if (dataEncoding != Encoding.PLAIN_DICTIONARY &&
+ dataEncoding != Encoding.RLE_DICTIONARY) {
+ throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+ }
this.dataColumn = new VectorizedRleValuesReader();
} else {
this.dataColumn = dataEncoding.getDictionaryBasedValuesReader(
@@ -771,8 +891,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
this.useDictionary = true;
} else {
- if (columnarBatch != null && dataEncoding == Encoding.PLAIN) {
- this.dataColumn = new VectorizedPlainValuesReader(4);
+ if (vectorizedDecode()) {
+ if (dataEncoding != Encoding.PLAIN) {
+ throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+ }
+ this.dataColumn = new VectorizedPlainValuesReader();
} else {
this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES);
}
@@ -791,10 +914,12 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
ValuesReader dlReader;
- // Initialize the decoders. Use custom ones if vectorized decoding is enabled.
- if (columnarBatch != null && page.getDlEncoding() == Encoding.RLE) {
+ // Initialize the decoders.
+ if (vectorizedDecode()) {
+ if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) {
+ throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding());
+ }
int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
- assert(bitWidth != 0); // not implemented
this.defColumn = new VectorizedRleValuesReader(bitWidth);
dlReader = this.defColumn;
} else {
@@ -818,8 +943,17 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
this.pageValueCount = page.getValueCount();
this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
page.getRepetitionLevels(), descriptor);
- this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
- page.getDefinitionLevels(), descriptor);
+
+ if (vectorizedDecode()) {
+ int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+ this.defColumn = new VectorizedRleValuesReader(bitWidth);
+ this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
+ this.defColumn.initFromBuffer(
+ this.pageValueCount, page.getDefinitionLevels().toByteArray());
+ } else {
+ this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
+ page.getDefinitionLevels(), descriptor);
+ }
try {
initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
} catch (IOException e) {
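Note on the reader above: after this change it exposes two access paths — the usual Hadoop RecordReader row-at-a-time API (nextKeyValue/getCurrentValue, which returns a view over the current ColumnarBatch when vectorized decoding is on) and whole-batch access via nextBatch(). A minimal Scala sketch of the batch path, assuming an already-initialized reader; resultBatch() is assumed to return the reader's ColumnarBatch (it is invoked from SqlNewHadoopRDD later in this diff) and numRows() is assumed to be the batch's row-count accessor:

  import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader

  // Sketch only: drives the reader batch-at-a-time instead of row-at-a-time.
  def countRowsVectorized(reader: UnsafeRowParquetRecordReader): Long = {
    val batch = reader.resultBatch()   // assumed accessor for the ColumnarBatch
    var rows = 0L
    while (reader.nextBatch()) {       // fills the batch; false once the file is exhausted
      rows += batch.numRows()
    }
    rows
  }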
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index cec2418e46..bf3283e853 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -32,10 +32,9 @@ import org.apache.parquet.io.api.Binary;
public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader {
private byte[] buffer;
private int offset;
- private final int byteSize;
+ private int bitOffset; // Only used for booleans.
- public VectorizedPlainValuesReader(int byteSize) {
- this.byteSize = byteSize;
+ public VectorizedPlainValuesReader() {
}
@Override
@@ -46,12 +45,15 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
@Override
public void skip() {
- offset += byteSize;
+ throw new UnsupportedOperationException();
}
@Override
- public void skip(int n) {
- offset += n * byteSize;
+ public final void readBooleans(int total, ColumnVector c, int rowId) {
+ // TODO: properly vectorize this
+ for (int i = 0; i < total; i++) {
+ c.putBoolean(rowId + i, readBoolean());
+ }
}
@Override
@@ -67,6 +69,18 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final void readFloats(int total, ColumnVector c, int rowId) {
+ c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 4 * total;
+ }
+
+ @Override
+ public final void readDoubles(int total, ColumnVector c, int rowId) {
+ c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 8 * total;
+ }
+
+ @Override
public final void readBytes(int total, ColumnVector c, int rowId) {
for (int i = 0; i < total; i++) {
// Bytes are stored as a 4-byte little endian int. Just read the first byte.
@@ -77,6 +91,18 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final boolean readBoolean() {
+ byte b = Platform.getByte(buffer, offset);
+ boolean v = (b & (1 << bitOffset)) != 0;
+ bitOffset += 1;
+ if (bitOffset == 8) {
+ bitOffset = 0;
+ offset++;
+ }
+ return v;
+ }
+
+ @Override
public final int readInteger() {
int v = Platform.getInt(buffer, offset);
offset += 4;
@@ -96,6 +122,20 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final float readFloat() {
+ float v = Platform.getFloat(buffer, offset);
+ offset += 4;
+ return v;
+ }
+
+ @Override
+ public final double readDouble() {
+ double v = Platform.getDouble(buffer, offset);
+ offset += 8;
+ return v;
+ }
+
+ @Override
public final void readBinary(int total, ColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
@@ -104,4 +144,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
}
}
+
+ @Override
+ public final Binary readBinary(int len) {
+ Binary result = Binary.fromByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len);
+ offset += len;
+ return result;
+ }
}
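Side note on readBoolean() above: PLAIN-encoded Parquet booleans are bit-packed eight per byte, least-significant bit first, which is what the bitOffset bookkeeping implements. A stand-alone sketch of the same unpacking (hypothetical helper, not part of the patch):

  def unpackPlainBooleans(buffer: Array[Byte], count: Int): Array[Boolean] = {
    val out = new Array[Boolean](count)
    var i = 0
    while (i < count) {
      val b = buffer(i >> 3)                 // byte holding bit i
      out(i) = (b & (1 << (i & 7))) != 0     // least-significant bit first
      i += 1
    }
    out
  }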
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 9bfd74db38..629959a73b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -87,13 +87,38 @@ public final class VectorizedRleValuesReader extends ValuesReader
this.offset = start;
this.in = page;
if (fixedWidth) {
- int length = readIntLittleEndian();
- this.end = this.offset + length;
+ if (bitWidth != 0) {
+ int length = readIntLittleEndian();
+ this.end = this.offset + length;
+ }
} else {
this.end = page.length;
if (this.end != this.offset) init(page[this.offset++] & 255);
}
- this.currentCount = 0;
+ if (bitWidth == 0) {
+ // 0 bit width, treat this as an RLE run of valueCount number of 0's.
+ this.mode = MODE.RLE;
+ this.currentCount = valueCount;
+ this.currentValue = 0;
+ } else {
+ this.currentCount = 0;
+ }
+ }
+
+ // Initialize the reader from a buffer. This is used for the V2 page encoding where the
+ // definition levels are in their own buffer.
+ public void initFromBuffer(int valueCount, byte[] data) {
+ this.offset = 0;
+ this.in = data;
+ this.end = data.length;
+ if (bitWidth == 0) {
+ // 0 bit width, treat this as an RLE run of valueCount number of 0's.
+ this.mode = MODE.RLE;
+ this.currentCount = valueCount;
+ this.currentValue = 0;
+ } else {
+ this.currentCount = 0;
+ }
}
/**
@@ -126,7 +151,6 @@ public final class VectorizedRleValuesReader extends ValuesReader
return readInteger();
}
-
@Override
public int readInteger() {
if (this.currentCount == 0) { this.readNextGroup(); }
@@ -189,6 +213,72 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
// TODO: can this code duplication be removed without a perf penalty?
+ public void readBooleans(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readBooleans(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putBoolean(rowId + i, data.readBoolean());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readIntsAsLongs(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ for (int i = 0; i < n; i++) {
+ c.putLong(rowId + i, data.readInteger());
+ }
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putLong(rowId + i, data.readInteger());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
public void readBytes(int total, ColumnVector c,
int rowId, int level, VectorizedValuesReader data) {
int left = total;
@@ -253,6 +343,70 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
}
+ public void readFloats(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readFloats(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putFloat(rowId + i, data.readFloat());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readDoubles(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readDoubles(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putDouble(rowId + i, data.readDouble());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
public void readBinarys(int total, ColumnVector c, int rowId, int level,
VectorizedValuesReader data) {
int left = total;
@@ -272,7 +426,7 @@ public final class VectorizedRleValuesReader extends ValuesReader
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putNotNull(rowId + i);
- data.readBinary(1, c, rowId);
+ data.readBinary(1, c, rowId + i);
} else {
c.putNull(rowId + i);
}
@@ -331,10 +485,24 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
@Override
- public void skip(int n) {
+ public void readBooleans(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readFloats(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readDoubles(int total, ColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
+ @Override
+ public Binary readBinary(int len) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
/**
* Reads the next varint encoded int.
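The new readBooleans/readIntsAsLongs/readFloats/readDoubles methods above all share one control flow: definition levels come from Parquet's RLE/bit-packed hybrid, so they are consumed either as an RLE run (one level repeated for the whole run) or as a bit-packed group (one level per value), and a value is non-null exactly when its level equals the column's max definition level. A simplified, hypothetical mirror of that dispatch, where readValue stands in for the underlying VectorizedValuesReader call:

  def decodeRun(
      isRleRun: Boolean,
      runLength: Int,
      rleLevel: Int,              // level repeated across an RLE run
      packedLevels: Array[Int],   // per-value levels of a bit-packed group
      maxDefLevel: Int)(readValue: () => Double): Array[Option[Double]] = {
    if (isRleRun) {
      if (rleLevel == maxDefLevel) Array.fill(runLength)(Some(readValue()))
      else Array.fill(runLength)(None)                  // the whole run is null
    } else {
      packedLevels.take(runLength).map { level =>
        if (level == maxDefLevel) Some(readValue()) else None
      }
    }
  }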
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index b6ec7311c5..88418ca53f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -19,24 +19,29 @@ package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.parquet.io.api.Binary;
+
/**
* Interface for value decoding that supports vectorized (aka batched) decoding.
* TODO: merge this into parquet-mr.
*/
public interface VectorizedValuesReader {
+ boolean readBoolean();
byte readByte();
int readInteger();
long readLong();
+ float readFloat();
+ double readDouble();
+ Binary readBinary(int len);
/*
* Reads `total` values into `c` start at `c[rowId]`
*/
+ void readBooleans(int total, ColumnVector c, int rowId);
void readBytes(int total, ColumnVector c, int rowId);
void readIntegers(int total, ColumnVector c, int rowId);
void readLongs(int total, ColumnVector c, int rowId);
+ void readFloats(int total, ColumnVector c, int rowId);
+ void readDoubles(int total, ColumnVector c, int rowId);
void readBinary(int total, ColumnVector c, int rowId);
-
- // TODO: add all the other parquet types.
-
- void skip(int n);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 453bc15e13..2aeef7f2f9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -18,11 +18,13 @@ package org.apache.spark.sql.execution.vectorized;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.sql.Date;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
@@ -100,6 +102,8 @@ public class ColumnVectorUtils {
dst.appendStruct(false);
dst.getChildColumn(0).appendInt(c.months);
dst.getChildColumn(1).appendLong(c.microseconds);
+ } else if (t instanceof DateType) {
+ dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
} else {
throw new NotImplementedException("Type " + t);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index dbad5e070f..070d897a71 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -23,11 +23,11 @@ import java.util.Iterator;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
-import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -62,6 +62,9 @@ public final class ColumnarBatch {
// Total number of rows that have been filtered.
private int numRowsFiltered = 0;
+ // Staging row returned from getRow.
+ final Row row;
+
public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
}
@@ -123,24 +126,36 @@ public final class ColumnarBatch {
@Override
/**
- * Revisit this. This is expensive.
+ * Revisit this. This is expensive. This is currently only used in test paths.
*/
public final InternalRow copy() {
- UnsafeRow row = new UnsafeRow(numFields());
- row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize);
+ GenericMutableRow row = new GenericMutableRow(columns.length);
for (int i = 0; i < numFields(); i++) {
if (isNullAt(i)) {
row.setNullAt(i);
} else {
DataType dt = columns[i].dataType();
- if (dt instanceof IntegerType) {
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof IntegerType) {
row.setInt(i, getInt(i));
} else if (dt instanceof LongType) {
row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
} else if (dt instanceof DoubleType) {
row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i));
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType)dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
} else {
- throw new RuntimeException("Not implemented.");
+ throw new RuntimeException("Not implemented. " + dt);
}
}
}
@@ -316,6 +331,16 @@ public final class ColumnarBatch {
public ColumnVector column(int ordinal) { return columns[ordinal]; }
/**
+ * Returns the row in this batch at `rowId`. Returned row is reused across calls.
+ */
+ public ColumnarBatch.Row getRow(int rowId) {
+ assert(rowId >= 0);
+ assert(rowId < numRows);
+ row.rowId = rowId;
+ return row;
+ }
+
+ /**
* Marks this row as being filtered out. This means a subsequent iteration over the rows
* in this batch will not include this row.
*/
@@ -335,5 +360,7 @@ public final class ColumnarBatch {
StructField field = schema.fields()[i];
columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
}
+
+ this.row = new Row(this);
}
}
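Because getRow() returns the single staging row added above and reuses it on every call, a caller that wants to keep rows past the next call has to copy() them — which is also why copy() was fleshed out for more types here. A hedged Scala sketch of that contract (assumes batch is a populated ColumnarBatch and that numRows() is the row-count accessor):

  import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.vectorized.ColumnarBatch

  def materialize(batch: ColumnarBatch): Seq[InternalRow] = {
    val out = new ArrayBuffer[InternalRow]
    var i = 0
    while (i < batch.numRows()) {
      out += batch.getRow(i).copy()   // copy() because getRow reuses one mutable row
      i += 1
    }
    out
  }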
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 7a224d19d1..c15f3d34a4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -22,6 +22,7 @@ import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
@@ -391,7 +392,8 @@ public final class OffHeapColumnVector extends ColumnVector {
this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
} else if (type instanceof ShortType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
- } else if (type instanceof IntegerType || type instanceof FloatType) {
+ } else if (type instanceof IntegerType || type instanceof FloatType ||
+ type instanceof DateType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type)) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index c42bbd642e..99548bc83b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -376,7 +376,7 @@ public final class OnHeapColumnVector extends ColumnVector {
short[] newData = new short[newCapacity];
if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
shortData = newData;
- } else if (type instanceof IntegerType) {
+ } else if (type instanceof IntegerType || type instanceof DateType) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index eb9da0bd4f..61a7b9935a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -345,6 +345,14 @@ private[spark] object SQLConf {
defaultValue = Some(true),
doc = "Enables using the custom ParquetUnsafeRowRecordReader.")
+ // Note: this cannot be enabled all the time because the reader will not be returning UnsafeRows.
+ // Doing so is very expensive and we should remove this requirement instead of fixing it here.
+ // Initial testing seems to indicate only sort requires this.
+ val PARQUET_VECTORIZED_READER_ENABLED = booleanConf(
+ key = "spark.sql.parquet.enableVectorizedReader",
+ defaultValue = Some(false),
+ doc = "Enables vectorized parquet decoding.")
+
val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
defaultValue = Some(false),
doc = "When true, enable filter pushdown for ORC files.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 3605150b3b..25911334a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -99,6 +99,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// a subset of the types (no complex types).
protected val enableUnsafeRowParquetReader: Boolean =
sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
+ protected val enableVectorizedParquetReader: Boolean =
+ sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
override def getPartitions: Array[SparkPartition] = {
val conf = getConf(isDriverSide = true)
@@ -176,6 +178,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
parquetReader.close()
} else {
reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
+ if (enableVectorizedParquetReader) parquetReader.resultBatch()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index fb97a03df6..1c0d53fc77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -65,7 +65,8 @@ private[parquet] class CatalystSchemaConverter(
def this(conf: Configuration) = this(
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
- writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean)
+ writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key,
+ SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean)
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index ab48e971b5..bd87449f92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -114,8 +114,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val path = new Path(location.getCanonicalPath)
val conf = sparkContext.hadoopConfiguration
writeMetadata(parquetSchema, path, conf)
- val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType)
- assert(sparkTypes === expectedSparkTypes)
+ readParquetFile(path.toString)(df => {
+ val sparkTypes = df.schema.map(_.dataType)
+ assert(sparkTypes === expectedSparkTypes)
+ })
}
}
@@ -142,7 +144,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ readParquetFile(dir.getCanonicalPath){ df => {
+ checkAnswer(df, data.collect().toSeq)
+ }}
}
}
}
@@ -158,7 +162,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { dir =>
val data = makeDateRDD()
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ readParquetFile(dir.getCanonicalPath) { df =>
+ checkAnswer(df, data.collect().toSeq)
+ }
}
}
@@ -335,9 +341,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
- checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
- Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
- })
+ readParquetFile(path.toString) { df =>
+ checkAnswer(df, (0 until 10).map { i =>
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) })
+ }
}
}
@@ -363,7 +370,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
- checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, newData.map(Row.fromTuple))
+ }
}
}
@@ -372,7 +381,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
- checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, data.map(Row.fromTuple))
+ }
}
}
@@ -392,7 +403,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
- checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, (data ++ newData).map(Row.fromTuple))
+ }
}
}
@@ -420,11 +433,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val conf = sparkContext.hadoopConfiguration
writeMetadata(parquetSchema, path, conf, extraMetadata)
- assertResult(sqlContext.read.parquet(path.toString).schema) {
- StructType(
- StructField("a", BooleanType, nullable = false) ::
- StructField("b", IntegerType, nullable = false) ::
- Nil)
+ readParquetFile(path.toString) { df =>
+ assertResult(df.schema) {
+ StructType(
+ StructField("a", BooleanType, nullable = false) ::
+ StructField("b", IntegerType, nullable = false) ::
+ Nil)
+ }
}
}
}
@@ -594,30 +609,43 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val path = s"${dir.getCanonicalPath}/data"
df.write.parquet(path)
- val df2 = sqlContext.read.parquet(path)
- assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ readParquetFile(path) { df2 =>
+ assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ }
}
}
test("read dictionary encoded decimals written as INT32") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i32.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-i32.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+ }
+ }
}
test("read dictionary encoded decimals written as INT64") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i64.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-i64.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+ }
+ }
}
test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-fixed-len.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-fixed-len.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
+ }
+ }
}
test("SPARK-12589 copy() on rows returned from reader works for strings") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 0bc64404f1..b123d2b31e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -45,7 +45,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
- withParquetTable(data, "t") {
+ // Query appends, don't test with both read modes.
+ withParquetTable(data, "t", false) {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
@@ -69,7 +70,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
(maybeInt, i.toString)
}
- withParquetTable(data, "t") {
+ // TODO: vectorized doesn't work here because it requires UnsafeRows
+ withParquetTable(data, "t", false) {
val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
val queryOutput = selfJoin.queryExecution.analyzed.output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index 14be9eec9a..e8893073e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -81,6 +81,12 @@ object ParquetReadBenchmark {
}
}
+ sqlBenchmark.addCase("SQL Parquet Vectorized") { iter =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ sqlContext.sql("select sum(id) from tempTable").collect()
+ }
+ }
+
val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
// Driving the parquet reader directly without Spark.
parquetReaderBenchmark.addCase("ParquetReader") { num =>
@@ -143,10 +149,11 @@ object ParquetReadBenchmark {
/*
Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ SQL Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
- SQL Parquet Reader 1682.6 15.58 1.00 X
- SQL Parquet MR 2379.6 11.02 0.71 X
+ SQL Parquet Reader 1350.56 11.65 1.00 X
+ SQL Parquet MR 1844.09 8.53 0.73 X
+ SQL Parquet Vectorized 1062.04 14.81 1.27 X
*/
sqlBenchmark.run()
@@ -185,6 +192,13 @@ object ParquetReadBenchmark {
}
}
+ benchmark.addCase("SQL Parquet Vectorized") { iter =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect
+ }
+ }
+
+
val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
benchmark.addCase("ParquetReader") { num =>
var sum1 = 0L
@@ -202,12 +216,13 @@ object ParquetReadBenchmark {
}
/*
- Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
- Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
- -------------------------------------------------------------------------
- SQL Parquet Reader 2245.6 7.00 1.00 X
- SQL Parquet MR 2914.2 5.40 0.77 X
- ParquetReader 1544.6 10.18 1.45 X
+ Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
+ Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ SQL Parquet Reader 1737.94 6.03 1.00 X
+ SQL Parquet MR 2393.08 4.38 0.73 X
+ SQL Parquet Vectorized 1442.99 7.27 1.20 X
+ ParquetReader 1032.11 10.16 1.68 X
*/
benchmark.run()
}
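For reading the benchmark tables above: the Relative Rate column is the baseline case's average time divided by each case's time, for example:

  // Single int column scan, numbers from the table above:
  val vectorizedSpeedup = 1350.56 / 1062.04   // ≈ 1.27 X relative to "SQL Parquet Reader"
  val mrSpeedup         = 1350.56 / 1844.09   // ≈ 0.73 X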
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 449fcc860f..5cbcccbd86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -44,6 +44,20 @@ import org.apache.spark.sql.types.StructType
private[sql] trait ParquetTest extends SQLTestUtils {
/**
+ * Reads the parquet file at `path`
+ */
+ protected def readParquetFile(path: String, testVectorized: Boolean = true)
+ (f: DataFrame => Unit) = {
+ (true :: false :: Nil).foreach { vectorized =>
+ if (!vectorized || testVectorized) {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
+ f(sqlContext.read.parquet(path.toString))
+ }
+ }
+ }
+ }
+
+ /**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
* returns.
*/
@@ -61,9 +75,9 @@ private[sql] trait ParquetTest extends SQLTestUtils {
* which is then passed to `f`. The Parquet file will be deleted after `f` returns.
*/
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
- (data: Seq[T])
+ (data: Seq[T], testVectorized: Boolean = true)
(f: DataFrame => Unit): Unit = {
- withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
+ withParquetFile(data)(path => readParquetFile(path.toString, testVectorized)(f))
}
/**
@@ -72,9 +86,9 @@ private[sql] trait ParquetTest extends SQLTestUtils {
* Parquet file will be dropped/deleted after `f` returns.
*/
protected def withParquetTable[T <: Product: ClassTag: TypeTag]
- (data: Seq[T], tableName: String)
+ (data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data, testVectorized) { df =>
sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
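Typical use of the new readParquetFile helper added above: the body runs twice, once with spark.sql.parquet.enableVectorizedReader on and once with it off, unless testVectorized is false (sketch; assumes a suite mixing in ParquetTest and a hypothetical path):

  readParquetFile("/tmp/example.parquet") { df =>
    assert(df.count() == 10)   // executed under both reader settings
  }

  // Tests that still depend on UnsafeRow output opt out, as several tests in this
  // patch do via withParquetTable(data, "t", false).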
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index 7841ffe5e0..b5af758a65 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -61,7 +61,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton
}
test("INSERT OVERWRITE TABLE Parquet table") {
- withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
+ // Don't run with vectorized: currently relies on UnsafeRow.
+ withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p")