author     Nong Li <nong@databricks.com>        2016-02-08 22:21:26 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-02-08 22:21:26 -0800
commit     3708d13f1a282a9ebf12e3b736f1aa1712cbacd5 (patch)
tree       fe654891efa6dcc0c67612bfb437b441dffa1793
parent     eeaf45b92695c577279f3a17d8c80ee40425e9aa (diff)
download   spark-3708d13f1a282a9ebf12e3b736f1aa1712cbacd5.tar.gz
           spark-3708d13f1a282a9ebf12e3b736f1aa1712cbacd5.tar.bz2
           spark-3708d13f1a282a9ebf12e3b736f1aa1712cbacd5.zip
[SPARK-12992] [SQL] Support vectorized decoding in UnsafeRowParquetRecordReader.
WIP: running tests. The code still needs a bit of cleanup.

This patch completes the vectorized decoding with the goal of passing the existing tests. There are still more patches needed to support the rest of the format spec, even just for flat schemas. This patch adds a new flag to enable the vectorized decoding. Tests were updated to run in both modes where applicable. Once this is working well, we can remove the previous code path.

Author: Nong Li <nong@databricks.com>

Closes #11055 from nongli/spark-12992-2.
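The new path is gated by the spark.sql.parquet.enableVectorizedReader flag introduced in SQLConf below (default false). A minimal sketch of turning it on for a session, assuming an existing sqlContext; the input path is made up for illustration:

    // Enable the vectorized Parquet decoding added by this patch, then read a file.
    // The conf key and its false default come from SQLConf.scala in this diff.
    sqlContext.setConf("spark.sql.parquet.enableVectorizedReader", "true")
    val df = sqlContext.read.parquet("/tmp/example.parquet")
    df.show()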
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java  | 174
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java   |  59
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java     | 180
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java        |  13
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java                      |   4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java                          |  39
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java                    |   4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java                     |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                                                   |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala                     |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala     |   3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala              |  86
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala           |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala        |  33
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala                 |  22
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala                                     |   3
16 files changed, 549 insertions, 90 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index b5dddb9f11..4576ac2a32 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -37,6 +37,7 @@ import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
@@ -44,7 +45,7 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.parquet.column.ValuesType.*;
@@ -57,7 +58,7 @@ import static org.apache.parquet.column.ValuesType.*;
* TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
* All of these can be handled efficiently and easily with codegen.
*/
-public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<UnsafeRow> {
+public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<InternalRow> {
/**
* Batch of unsafe rows that we assemble and the current index we've returned. Every time this
* batch is used up (batchIdx == numBatched), we populate the batch.
@@ -110,6 +111,9 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* and currently unsupported cases will fail with potentially difficult to diagnose errors.
* This should be only turned on for development to work on this feature.
*
+ * When this is set, the code will branch early on in the RecordReader APIs. There is no shared
+ * code between the path that uses the MR decoders and the vectorized ones.
+ *
* TODOs:
* - Implement all the encodings to support vectorized.
* - Implement v2 page formats (just make sure we create the correct decoders).
@@ -166,15 +170,23 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
if (batchIdx >= numBatched) {
- if (!loadBatch()) return false;
+ if (vectorizedDecode()) {
+ if (!nextBatch()) return false;
+ } else {
+ if (!loadBatch()) return false;
+ }
}
++batchIdx;
return true;
}
@Override
- public UnsafeRow getCurrentValue() throws IOException, InterruptedException {
- return rows[batchIdx - 1];
+ public InternalRow getCurrentValue() throws IOException, InterruptedException {
+ if (vectorizedDecode()) {
+ return columnarBatch.getRow(batchIdx - 1);
+ } else {
+ return rows[batchIdx - 1];
+ }
}
@Override
@@ -202,20 +214,27 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* Advances to the next batch of rows. Returns false if there are no more.
*/
public boolean nextBatch() throws IOException {
- assert(columnarBatch != null);
+ assert(vectorizedDecode());
columnarBatch.reset();
if (rowsReturned >= totalRowCount) return false;
checkEndOfRowGroup();
- int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
+ int num = (int)Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
columnReaders[i].readBatch(num, columnarBatch.column(i));
}
rowsReturned += num;
columnarBatch.setNumRows(num);
+ numBatched = num;
+ batchIdx = 0;
return true;
}
+ /**
+ * Returns true if we are doing a vectorized decode.
+ */
+ private boolean vectorizedDecode() { return columnarBatch != null; }
+
private void initializeInternal() throws IOException {
/**
* Check that the requested schema is supported.
@@ -613,15 +632,27 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
decodeDictionaryIds(rowId, num, column);
} else {
switch (descriptor.getType()) {
+ case BOOLEAN:
+ readBooleanBatch(rowId, num, column);
+ break;
case INT32:
readIntBatch(rowId, num, column);
break;
case INT64:
readLongBatch(rowId, num, column);
break;
+ case FLOAT:
+ readFloatBatch(rowId, num, column);
+ break;
+ case DOUBLE:
+ readDoubleBatch(rowId, num, column);
+ break;
case BINARY:
readBinaryBatch(rowId, num, column);
break;
+ case FIXED_LEN_BYTE_ARRAY:
+ readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength());
+ break;
default:
throw new IOException("Unsupported type: " + descriptor.getType());
}
@@ -645,7 +676,15 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
} else if (column.dataType() == DataTypes.ByteType) {
for (int i = rowId; i < rowId + num; ++i) {
- column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ShortType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
}
} else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
@@ -653,8 +692,36 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
break;
case INT64:
+ if (column.dataType() == DataTypes.LongType ||
+ DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
+ case FLOAT:
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case DOUBLE:
for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case FIXED_LEN_BYTE_ARRAY:
+ if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
+ }
+ } else {
+ throw new NotImplementedException();
}
break;
@@ -691,15 +758,24 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* is guaranteed that num is smaller than the number of values left in the current page.
*/
+ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
+ assert(column.dataType() == DataTypes.BooleanType);
+ defColumn.readBooleans(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ }
+
private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.IntegerType) {
+ if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ defColumn.readIntsAsLongs(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
}
@@ -707,11 +783,33 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
- // TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.LongType) {
+ if (column.dataType() == DataTypes.LongType ||
+ DecimalType.is64BitDecimalType(column.dataType())) {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
+ throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ }
+ }
+
+ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: support implicit cast to double?
+ if (column.dataType() == DataTypes.FloatType) {
+ defColumn.readFloats(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ }
+ }
+
+ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.DoubleType) {
+ defColumn.readDoubles(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
throw new NotImplementedException("Unimplemented type: " + column.dataType());
}
}
@@ -727,6 +825,24 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
}
+ private void readFixedLenByteArrayBatch(int rowId, int num,
+ ColumnVector column, int arrayLen) throws IOException {
+ VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ column.putLong(rowId + i,
+ CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+ } else {
+ column.putNull(rowId + i);
+ }
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
private void readPage() throws IOException {
DataPage page = pageReader.readPage();
@@ -763,7 +879,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
"could not read page in col " + descriptor +
" as the dictionary was missing for encoding " + dataEncoding);
}
- if (columnarBatch != null && dataEncoding == Encoding.PLAIN_DICTIONARY) {
+ if (vectorizedDecode()) {
+ if (dataEncoding != Encoding.PLAIN_DICTIONARY &&
+ dataEncoding != Encoding.RLE_DICTIONARY) {
+ throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+ }
this.dataColumn = new VectorizedRleValuesReader();
} else {
this.dataColumn = dataEncoding.getDictionaryBasedValuesReader(
@@ -771,8 +891,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
}
this.useDictionary = true;
} else {
- if (columnarBatch != null && dataEncoding == Encoding.PLAIN) {
- this.dataColumn = new VectorizedPlainValuesReader(4);
+ if (vectorizedDecode()) {
+ if (dataEncoding != Encoding.PLAIN) {
+ throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+ }
+ this.dataColumn = new VectorizedPlainValuesReader();
} else {
this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES);
}
@@ -791,10 +914,12 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
ValuesReader dlReader;
- // Initialize the decoders. Use custom ones if vectorized decoding is enabled.
- if (columnarBatch != null && page.getDlEncoding() == Encoding.RLE) {
+ // Initialize the decoders.
+ if (vectorizedDecode()) {
+ if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) {
+ throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding());
+ }
int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
- assert(bitWidth != 0); // not implemented
this.defColumn = new VectorizedRleValuesReader(bitWidth);
dlReader = this.defColumn;
} else {
@@ -818,8 +943,17 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
this.pageValueCount = page.getValueCount();
this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
page.getRepetitionLevels(), descriptor);
- this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
- page.getDefinitionLevels(), descriptor);
+
+ if (vectorizedDecode()) {
+ int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+ this.defColumn = new VectorizedRleValuesReader(bitWidth);
+ this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
+ this.defColumn.initFromBuffer(
+ this.pageValueCount, page.getDefinitionLevels().toByteArray());
+ } else {
+ this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
+ page.getDefinitionLevels(), descriptor);
+ }
try {
initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
} catch (IOException e) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index cec2418e46..bf3283e853 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -32,10 +32,9 @@ import org.apache.parquet.io.api.Binary;
public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader {
private byte[] buffer;
private int offset;
- private final int byteSize;
+ private int bitOffset; // Only used for booleans.
- public VectorizedPlainValuesReader(int byteSize) {
- this.byteSize = byteSize;
+ public VectorizedPlainValuesReader() {
}
@Override
@@ -46,12 +45,15 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
@Override
public void skip() {
- offset += byteSize;
+ throw new UnsupportedOperationException();
}
@Override
- public void skip(int n) {
- offset += n * byteSize;
+ public final void readBooleans(int total, ColumnVector c, int rowId) {
+ // TODO: properly vectorize this
+ for (int i = 0; i < total; i++) {
+ c.putBoolean(rowId + i, readBoolean());
+ }
}
@Override
@@ -67,6 +69,18 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final void readFloats(int total, ColumnVector c, int rowId) {
+ c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 4 * total;
+ }
+
+ @Override
+ public final void readDoubles(int total, ColumnVector c, int rowId) {
+ c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 8 * total;
+ }
+
+ @Override
public final void readBytes(int total, ColumnVector c, int rowId) {
for (int i = 0; i < total; i++) {
// Bytes are stored as a 4-byte little endian int. Just read the first byte.
@@ -77,6 +91,18 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final boolean readBoolean() {
+ byte b = Platform.getByte(buffer, offset);
+ boolean v = (b & (1 << bitOffset)) != 0;
+ bitOffset += 1;
+ if (bitOffset == 8) {
+ bitOffset = 0;
+ offset++;
+ }
+ return v;
+ }
+
+ @Override
public final int readInteger() {
int v = Platform.getInt(buffer, offset);
offset += 4;
@@ -96,6 +122,20 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final float readFloat() {
+ float v = Platform.getFloat(buffer, offset);
+ offset += 4;
+ return v;
+ }
+
+ @Override
+ public final double readDouble() {
+ double v = Platform.getDouble(buffer, offset);
+ offset += 8;
+ return v;
+ }
+
+ @Override
public final void readBinary(int total, ColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
@@ -104,4 +144,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
}
}
+
+ @Override
+ public final Binary readBinary(int len) {
+ Binary result = Binary.fromByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len);
+ offset += len;
+ return result;
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 9bfd74db38..629959a73b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -87,13 +87,38 @@ public final class VectorizedRleValuesReader extends ValuesReader
this.offset = start;
this.in = page;
if (fixedWidth) {
- int length = readIntLittleEndian();
- this.end = this.offset + length;
+ if (bitWidth != 0) {
+ int length = readIntLittleEndian();
+ this.end = this.offset + length;
+ }
} else {
this.end = page.length;
if (this.end != this.offset) init(page[this.offset++] & 255);
}
- this.currentCount = 0;
+ if (bitWidth == 0) {
+ // 0 bit width, treat this as an RLE run of valueCount number of 0's.
+ this.mode = MODE.RLE;
+ this.currentCount = valueCount;
+ this.currentValue = 0;
+ } else {
+ this.currentCount = 0;
+ }
+ }
+
+ // Initialize the reader from a buffer. This is used for the V2 page encoding where the
+ // definition levels are in their own buffer.
+ public void initFromBuffer(int valueCount, byte[] data) {
+ this.offset = 0;
+ this.in = data;
+ this.end = data.length;
+ if (bitWidth == 0) {
+ // 0 bit width, treat this as an RLE run of valueCount number of 0's.
+ this.mode = MODE.RLE;
+ this.currentCount = valueCount;
+ this.currentValue = 0;
+ } else {
+ this.currentCount = 0;
+ }
}
/**
@@ -126,7 +151,6 @@ public final class VectorizedRleValuesReader extends ValuesReader
return readInteger();
}
-
@Override
public int readInteger() {
if (this.currentCount == 0) { this.readNextGroup(); }
@@ -189,6 +213,72 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
// TODO: can this code duplication be removed without a perf penalty?
+ public void readBooleans(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readBooleans(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putBoolean(rowId + i, data.readBoolean());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readIntsAsLongs(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ for (int i = 0; i < n; i++) {
+ c.putLong(rowId + i, data.readInteger());
+ }
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putLong(rowId + i, data.readInteger());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
public void readBytes(int total, ColumnVector c,
int rowId, int level, VectorizedValuesReader data) {
int left = total;
@@ -253,6 +343,70 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
}
+ public void readFloats(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readFloats(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putFloat(rowId + i, data.readFloat());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readDoubles(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readDoubles(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putDouble(rowId + i, data.readDouble());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
public void readBinarys(int total, ColumnVector c, int rowId, int level,
VectorizedValuesReader data) {
int left = total;
@@ -272,7 +426,7 @@ public final class VectorizedRleValuesReader extends ValuesReader
for (int i = 0; i < n; ++i) {
if (currentBuffer[currentBufferIdx++] == level) {
c.putNotNull(rowId + i);
- data.readBinary(1, c, rowId);
+ data.readBinary(1, c, rowId + i);
} else {
c.putNull(rowId + i);
}
@@ -331,10 +485,24 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
@Override
- public void skip(int n) {
+ public void readBooleans(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readFloats(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readDoubles(int total, ColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
+ @Override
+ public Binary readBinary(int len) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
/**
* Reads the next varint encoded int.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index b6ec7311c5..88418ca53f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -19,24 +19,29 @@ package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.parquet.io.api.Binary;
+
/**
* Interface for value decoding that supports vectorized (aka batched) decoding.
* TODO: merge this into parquet-mr.
*/
public interface VectorizedValuesReader {
+ boolean readBoolean();
byte readByte();
int readInteger();
long readLong();
+ float readFloat();
+ double readDouble();
+ Binary readBinary(int len);
/*
* Reads `total` values into `c` start at `c[rowId]`
*/
+ void readBooleans(int total, ColumnVector c, int rowId);
void readBytes(int total, ColumnVector c, int rowId);
void readIntegers(int total, ColumnVector c, int rowId);
void readLongs(int total, ColumnVector c, int rowId);
+ void readFloats(int total, ColumnVector c, int rowId);
+ void readDoubles(int total, ColumnVector c, int rowId);
void readBinary(int total, ColumnVector c, int rowId);
-
- // TODO: add all the other parquet types.
-
- void skip(int n);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 453bc15e13..2aeef7f2f9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -18,11 +18,13 @@ package org.apache.spark.sql.execution.vectorized;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.sql.Date;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
@@ -100,6 +102,8 @@ public class ColumnVectorUtils {
dst.appendStruct(false);
dst.getChildColumn(0).appendInt(c.months);
dst.getChildColumn(1).appendLong(c.microseconds);
+ } else if (t instanceof DateType) {
+ dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
} else {
throw new NotImplementedException("Type " + t);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index dbad5e070f..070d897a71 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -23,11 +23,11 @@ import java.util.Iterator;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
-import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -62,6 +62,9 @@ public final class ColumnarBatch {
// Total number of rows that have been filtered.
private int numRowsFiltered = 0;
+ // Staging row returned from getRow.
+ final Row row;
+
public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
}
@@ -123,24 +126,36 @@ public final class ColumnarBatch {
@Override
/**
- * Revisit this. This is expensive.
+ * Revisit this. This is expensive. This is currently only used in test paths.
*/
public final InternalRow copy() {
- UnsafeRow row = new UnsafeRow(numFields());
- row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize);
+ GenericMutableRow row = new GenericMutableRow(columns.length);
for (int i = 0; i < numFields(); i++) {
if (isNullAt(i)) {
row.setNullAt(i);
} else {
DataType dt = columns[i].dataType();
- if (dt instanceof IntegerType) {
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof IntegerType) {
row.setInt(i, getInt(i));
} else if (dt instanceof LongType) {
row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
} else if (dt instanceof DoubleType) {
row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i));
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType)dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
} else {
- throw new RuntimeException("Not implemented.");
+ throw new RuntimeException("Not implemented. " + dt);
}
}
}
@@ -316,6 +331,16 @@ public final class ColumnarBatch {
public ColumnVector column(int ordinal) { return columns[ordinal]; }
/**
+ * Returns the row in this batch at `rowId`. Returned row is reused across calls.
+ */
+ public ColumnarBatch.Row getRow(int rowId) {
+ assert(rowId >= 0);
+ assert(rowId < numRows);
+ row.rowId = rowId;
+ return row;
+ }
+
+ /**
* Marks this row as being filtered out. This means a subsequent iteration over the rows
* in this batch will not include this row.
*/
@@ -335,5 +360,7 @@ public final class ColumnarBatch {
StructField field = schema.fields()[i];
columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
}
+
+ this.row = new Row(this);
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 7a224d19d1..c15f3d34a4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -22,6 +22,7 @@ import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
@@ -391,7 +392,8 @@ public final class OffHeapColumnVector extends ColumnVector {
this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
} else if (type instanceof ShortType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
- } else if (type instanceof IntegerType || type instanceof FloatType) {
+ } else if (type instanceof IntegerType || type instanceof FloatType ||
+ type instanceof DateType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type)) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index c42bbd642e..99548bc83b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -376,7 +376,7 @@ public final class OnHeapColumnVector extends ColumnVector {
short[] newData = new short[newCapacity];
if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
shortData = newData;
- } else if (type instanceof IntegerType) {
+ } else if (type instanceof IntegerType || type instanceof DateType) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index eb9da0bd4f..61a7b9935a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -345,6 +345,14 @@ private[spark] object SQLConf {
defaultValue = Some(true),
doc = "Enables using the custom ParquetUnsafeRowRecordReader.")
+ // Note: this can not be enabled all the time because the reader will not be returning UnsafeRows.
+ // Doing so is very expensive and we should remove this requirement instead of fixing it here.
+ // Initial testing seems to indicate only sort requires this.
+ val PARQUET_VECTORIZED_READER_ENABLED = booleanConf(
+ key = "spark.sql.parquet.enableVectorizedReader",
+ defaultValue = Some(false),
+ doc = "Enables vectorized parquet decoding.")
+
val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
defaultValue = Some(false),
doc = "When true, enable filter pushdown for ORC files.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 3605150b3b..25911334a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -99,6 +99,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// a subset of the types (no complex types).
protected val enableUnsafeRowParquetReader: Boolean =
sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
+ protected val enableVectorizedParquetReader: Boolean =
+ sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
override def getPartitions: Array[SparkPartition] = {
val conf = getConf(isDriverSide = true)
@@ -176,6 +178,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
parquetReader.close()
} else {
reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
+ if (enableVectorizedParquetReader) parquetReader.resultBatch()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index fb97a03df6..1c0d53fc77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -65,7 +65,8 @@ private[parquet] class CatalystSchemaConverter(
def this(conf: Configuration) = this(
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
- writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean)
+ writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key,
+ SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean)
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index ab48e971b5..bd87449f92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -114,8 +114,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val path = new Path(location.getCanonicalPath)
val conf = sparkContext.hadoopConfiguration
writeMetadata(parquetSchema, path, conf)
- val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType)
- assert(sparkTypes === expectedSparkTypes)
+ readParquetFile(path.toString)(df => {
+ val sparkTypes = df.schema.map(_.dataType)
+ assert(sparkTypes === expectedSparkTypes)
+ })
}
}
@@ -142,7 +144,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ readParquetFile(dir.getCanonicalPath){ df => {
+ checkAnswer(df, data.collect().toSeq)
+ }}
}
}
}
@@ -158,7 +162,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { dir =>
val data = makeDateRDD()
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ readParquetFile(dir.getCanonicalPath) { df =>
+ checkAnswer(df, data.collect().toSeq)
+ }
}
}
@@ -335,9 +341,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
- checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
- Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
- })
+ readParquetFile(path.toString) { df =>
+ checkAnswer(df, (0 until 10).map { i =>
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) })
+ }
}
}
@@ -363,7 +370,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
- checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, newData.map(Row.fromTuple))
+ }
}
}
@@ -372,7 +381,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
- checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, data.map(Row.fromTuple))
+ }
}
}
@@ -392,7 +403,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
- checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
+ readParquetFile(file) { df =>
+ checkAnswer(df, (data ++ newData).map(Row.fromTuple))
+ }
}
}
@@ -420,11 +433,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val conf = sparkContext.hadoopConfiguration
writeMetadata(parquetSchema, path, conf, extraMetadata)
- assertResult(sqlContext.read.parquet(path.toString).schema) {
- StructType(
- StructField("a", BooleanType, nullable = false) ::
- StructField("b", IntegerType, nullable = false) ::
- Nil)
+ readParquetFile(path.toString) { df =>
+ assertResult(df.schema) {
+ StructType(
+ StructField("a", BooleanType, nullable = false) ::
+ StructField("b", IntegerType, nullable = false) ::
+ Nil)
+ }
}
}
}
@@ -594,30 +609,43 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val path = s"${dir.getCanonicalPath}/data"
df.write.parquet(path)
- val df2 = sqlContext.read.parquet(path)
- assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ readParquetFile(path) { df2 =>
+ assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ }
}
}
test("read dictionary encoded decimals written as INT32") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i32.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-i32.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+ }
+ }
}
test("read dictionary encoded decimals written as INT64") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i64.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-i64.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+ }
+ }
}
test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") {
- checkAnswer(
- // Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-fixed-len.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
+ ("true" :: "false" :: Nil).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ checkAnswer(
+ // Decimal column in this file is encoded using plain dictionary
+ readResourceParquetFile("dec-in-fixed-len.parquet"),
+ sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
+ }
+ }
}
test("SPARK-12589 copy() on rows returned from reader works for strings") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 0bc64404f1..b123d2b31e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -45,7 +45,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
- withParquetTable(data, "t") {
+ // Query appends, don't test with both read modes.
+ withParquetTable(data, "t", false) {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
@@ -69,7 +70,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
(maybeInt, i.toString)
}
- withParquetTable(data, "t") {
+ // TODO: vectorized doesn't work here because it requires UnsafeRows
+ withParquetTable(data, "t", false) {
val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
val queryOutput = selfJoin.queryExecution.analyzed.output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index 14be9eec9a..e8893073e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -81,6 +81,12 @@ object ParquetReadBenchmark {
}
}
+ sqlBenchmark.addCase("SQL Parquet Vectorized") { iter =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ sqlContext.sql("select sum(id) from tempTable").collect()
+ }
+ }
+
val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
// Driving the parquet reader directly without Spark.
parquetReaderBenchmark.addCase("ParquetReader") { num =>
@@ -143,10 +149,11 @@ object ParquetReadBenchmark {
/*
Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ SQL Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
- SQL Parquet Reader 1682.6 15.58 1.00 X
- SQL Parquet MR 2379.6 11.02 0.71 X
+ SQL Parquet Reader 1350.56 11.65 1.00 X
+ SQL Parquet MR 1844.09 8.53 0.73 X
+ SQL Parquet Vectorized 1062.04 14.81 1.27 X
*/
sqlBenchmark.run()
@@ -185,6 +192,13 @@ object ParquetReadBenchmark {
}
}
+ benchmark.addCase("SQL Parquet Vectorized") { iter =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect
+ }
+ }
+
+
val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
benchmark.addCase("ParquetReader") { num =>
var sum1 = 0L
@@ -202,12 +216,13 @@ object ParquetReadBenchmark {
}
/*
- Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
- Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
- -------------------------------------------------------------------------
- SQL Parquet Reader 2245.6 7.00 1.00 X
- SQL Parquet MR 2914.2 5.40 0.77 X
- ParquetReader 1544.6 10.18 1.45 X
+ Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
+ Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ SQL Parquet Reader 1737.94 6.03 1.00 X
+ SQL Parquet MR 2393.08 4.38 0.73 X
+ SQL Parquet Vectorized 1442.99 7.27 1.20 X
+ ParquetReader 1032.11 10.16 1.68 X
*/
benchmark.run()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 449fcc860f..5cbcccbd86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -44,6 +44,20 @@ import org.apache.spark.sql.types.StructType
private[sql] trait ParquetTest extends SQLTestUtils {
/**
+ * Reads the parquet file at `path`
+ */
+ protected def readParquetFile(path: String, testVectorized: Boolean = true)
+ (f: DataFrame => Unit) = {
+ (true :: false :: Nil).foreach { vectorized =>
+ if (!vectorized || testVectorized) {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
+ f(sqlContext.read.parquet(path.toString))
+ }
+ }
+ }
+ }
+
+ /**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
* returns.
*/
@@ -61,9 +75,9 @@ private[sql] trait ParquetTest extends SQLTestUtils {
* which is then passed to `f`. The Parquet file will be deleted after `f` returns.
*/
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
- (data: Seq[T])
+ (data: Seq[T], testVectorized: Boolean = true)
(f: DataFrame => Unit): Unit = {
- withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
+ withParquetFile(data)(path => readParquetFile(path.toString, testVectorized)(f))
}
/**
@@ -72,9 +86,9 @@ private[sql] trait ParquetTest extends SQLTestUtils {
* Parquet file will be dropped/deleted after `f` returns.
*/
protected def withParquetTable[T <: Product: ClassTag: TypeTag]
- (data: Seq[T], tableName: String)
+ (data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
- withParquetDataFrame(data) { df =>
+ withParquetDataFrame(data, testVectorized) { df =>
sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
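A hedged sketch of how a suite mixing in ParquetTest might use the new helpers; checkAnswer and sql come from the test traits, while data, expected, and dir are hypothetical fixtures:

    // Runs the body once with spark.sql.parquet.enableVectorizedReader=true and once with it false.
    readParquetFile(dir.getCanonicalPath) { df =>
      checkAnswer(df, expected)
    }

    // Tests that still depend on UnsafeRows (e.g. appends) can skip the vectorized pass.
    withParquetTable(data, "t", testVectorized = false) {
      sql("INSERT INTO TABLE t SELECT * FROM tmp")
    }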
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index 7841ffe5e0..b5af758a65 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -61,7 +61,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton
}
test("INSERT OVERWRITE TABLE Parquet table") {
- withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
+ // Don't run with vectorized: currently relies on UnsafeRow.
+ withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p")