author     Nong Li <nong@databricks.com>        2016-02-02 16:33:21 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-02-02 16:33:21 -0800
commit     21112e8a14c042ccef4312079672108a1082a95e (patch)
tree       a819ca9e2707b5d60fecea003aad42f06e8905df
parent     672032d0ab1e43bc5a25cecdb1b96dfd35c39778 (diff)
[SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch.
This patch implements support for more types when doing the vectorized decode. There are a few more types remaining but they should be very straightforward after this. This code has a few copy and paste pieces but they are difficult to eliminate due to performance considerations.

Specifically, this patch adds support for:
- String, Long, Byte types
- Dictionary encoding for those types.

Author: Nong Li <nong@databricks.com>

Closes #10908 from nongli/spark-12992.
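For context on the dictionary-encoding half of this change, here is a minimal, self-contained sketch of the idea (the class name and sample data are illustrative only, not part of the patch): a dictionary-encoded Parquet column stores small integer ids that index into a per-column dictionary of distinct values, and the vectorized path first materializes a batch of ids and then resolves them against the dictionary, which is the readBatch/decodeDictionaryIds split in the diff below.

import java.util.Arrays;

public class DictionaryDecodeSketch {
  public static void main(String[] args) {
    // Distinct values for the column, as a Parquet dictionary page would hold them.
    String[] dictionary = {"abc", "def", "ghi"};
    // Encoded column data: one dictionary id per row.
    int[] ids = {0, 0, 2, 1, 0};

    // Batch decode: resolve every id against the dictionary.
    String[] decoded = new String[ids.length];
    for (int i = 0; i < ids.length; i++) {
      decoded[i] = dictionary[ids[i]];
    }
    System.out.println(Arrays.toString(decoded)); // [abc, abc, ghi, def, abc]
  }
}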
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java  146
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java    45
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java     160
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java          5
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java                             7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala         82
6 files changed, 424 insertions(+), 21 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 17adfec321..b5dddb9f11 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,6 +21,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.Preconditions;
@@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
@@ -207,13 +209,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
- switch (columnReaders[i].descriptor.getType()) {
- case INT32:
- columnReaders[i].readIntBatch(num, columnarBatch.column(i));
- break;
- default:
- throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType());
- }
+ columnReaders[i].readBatch(num, columnarBatch.column(i));
}
rowsReturned += num;
columnarBatch.setNumRows(num);
@@ -237,7 +233,8 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
// TODO: Be extremely cautious in what is supported. Expand this.
if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
- originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) {
+ originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
+ originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
throw new IOException("Unsupported type: " + t);
}
if (originalTypes[i] == OriginalType.DECIMAL &&
@@ -465,6 +462,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private boolean useDictionary;
/**
+ * If useDictionary is true, the staging vector used to decode the ids.
+ */
+ private ColumnVector dictionaryIds;
+
+ /**
* Maximum definition level for this column.
*/
private final int maxDefLevel;
@@ -587,9 +589,8 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
/**
* Reads `total` values from this columnReader into column.
- * TODO: implement the other encodings.
*/
- private void readIntBatch(int total, ColumnVector column) throws IOException {
+ private void readBatch(int total, ColumnVector column) throws IOException {
int rowId = 0;
while (total > 0) {
// Compute the number of values we want to read in this page.
@@ -599,21 +600,134 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
leftInPage = (int)(endOfPageValueCount - valuesRead);
}
int num = Math.min(total, leftInPage);
- defColumn.readIntegers(
- num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0);
-
- // Remap the values if it is dictionary encoded.
if (useDictionary) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putInt(i, dictionary.decodeToInt(column.getInt(i)));
+ // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
+ if (dictionaryIds == null) {
+ dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(total);
+ }
+ // Read and decode dictionary ids.
+ readIntBatch(rowId, num, dictionaryIds);
+ decodeDictionaryIds(rowId, num, column);
+ } else {
+ switch (descriptor.getType()) {
+ case INT32:
+ readIntBatch(rowId, num, column);
+ break;
+ case INT64:
+ readLongBatch(rowId, num, column);
+ break;
+ case BINARY:
+ readBinaryBatch(rowId, num, column);
+ break;
+ default:
+ throw new IOException("Unsupported type: " + descriptor.getType());
}
}
+
valuesRead += num;
rowId += num;
total -= num;
}
}
+ /**
+ * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
+ */
+ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+ switch (descriptor.getType()) {
+ case INT32:
+ if (column.dataType() == DataTypes.IntegerType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ByteType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
+ case INT64:
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case BINARY:
+ // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
+ // need to do this better. We should probably add the dictionary data to the ColumnVector
+ // and reuse it across batches. This should mean adding a ByteArray would just update
+ // the length and offset.
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putByteArray(i, v.getBytes());
+ }
+ break;
+
+ default:
+ throw new NotImplementedException("Unsupported type: " + descriptor.getType());
+ }
+
+ if (dictionaryIds.numNulls() > 0) {
+ // Copy the NULLs over.
+ // TODO: we can improve this by decoding the NULLs directly into column. This would
+ // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then
+ // just do the ID remapping as above.
+ for (int i = 0; i < num; ++i) {
+ if (dictionaryIds.getIsNull(rowId + i)) {
+ column.putNull(rowId + i);
+ }
+ }
+ }
+ }
+
+ /**
+ * For all the read*Batch functions, reads `num` values from this columnReader into column. It
+ * is guaranteed that num is smaller than the number of values left in the current page.
+ */
+
+ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.IntegerType) {
+ defColumn.readIntegers(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
+ } else if (column.dataType() == DataTypes.ByteType) {
+ defColumn.readBytes(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.LongType) {
+ defColumn.readLongs(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.isArray()) {
+ defColumn.readBinarys(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+
private void readPage() throws IOException {
DataPage page = pageReader.readPage();
// TODO: Why is this a visitor?
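A small, hedged illustration of why the INT32 paths above can also populate a ByteType column (the class and sample values are made up for illustration): Parquet has no one-byte physical type, so INT_8 and INT_16 logical types are stored as INT32 on disk, and the reader narrows each decoded int to the column's Java type, as column.putByte(i, (byte) ...) does above.

public class LogicalTypeNarrowingSketch {
  public static void main(String[] args) {
    // On-disk representation of an INT_8 column: full 32-bit ints.
    int[] physicalInt32 = {1, 2, 127, -128};

    // Narrow each value into a byte-typed column, mirroring column.putByte(i, (byte) v).
    byte[] byteColumn = new byte[physicalInt32.length];
    for (int i = 0; i < physicalInt32.length; i++) {
      byteColumn[i] = (byte) physicalInt32[i];
    }

    for (byte b : byteColumn) {
      System.out.println(b); // 1, 2, 127, -128
    }
  }
}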
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index dac0c52ebd..cec2418e46 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.unsafe.Platform;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;
/**
* An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
@@ -52,15 +55,53 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
- public void readIntegers(int total, ColumnVector c, int rowId) {
+ public final void readIntegers(int total, ColumnVector c, int rowId) {
c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}
@Override
- public int readInteger() {
+ public final void readLongs(int total, ColumnVector c, int rowId) {
+ c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 8 * total;
+ }
+
+ @Override
+ public final void readBytes(int total, ColumnVector c, int rowId) {
+ for (int i = 0; i < total; i++) {
+ // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+ // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+ c.putInt(rowId + i, buffer[offset]);
+ offset += 4;
+ }
+ }
+
+ @Override
+ public final int readInteger() {
int v = Platform.getInt(buffer, offset);
offset += 4;
return v;
}
+
+ @Override
+ public final long readLong() {
+ long v = Platform.getLong(buffer, offset);
+ offset += 8;
+ return v;
+ }
+
+ @Override
+ public final byte readByte() {
+ return (byte)readInteger();
+ }
+
+ @Override
+ public final void readBinary(int total, ColumnVector v, int rowId) {
+ for (int i = 0; i < total; i++) {
+ int len = readInteger();
+ int start = offset;
+ offset += len;
+ v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+ }
+ }
}
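As a hedged, standalone sketch of the PLAIN layouts decoded above (the buffer contents are made-up example data, not from the patch): a byte column is stored as 4-byte little-endian ints, so readBytes() only needs the first byte of each word, while a BINARY value is a 4-byte little-endian length followed by that many bytes, which is what readBinary() walks through.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class PlainEncodingSketch {
  public static void main(String[] args) {
    // Byte value 7 stored as the little-endian INT32 word 07 00 00 00;
    // only the first (lowest) byte of the word is needed.
    byte[] byteAsInt32 = {7, 0, 0, 0};
    System.out.println(byteAsInt32[0]); // 7

    // BINARY value "abc" stored as its length (4-byte little-endian) plus the payload.
    byte[] payload = "abc".getBytes(StandardCharsets.UTF_8);
    ByteBuffer buf = ByteBuffer.allocate(4 + payload.length).order(ByteOrder.LITTLE_ENDIAN);
    buf.putInt(payload.length);
    buf.put(payload);
    buf.flip();

    int len = buf.getInt();        // read the length...
    byte[] value = new byte[len];
    buf.get(value);                // ...then copy `len` bytes
    System.out.println(new String(value, StandardCharsets.UTF_8)); // abc
  }
}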
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 493ec9deed..9bfd74db38 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -17,12 +17,16 @@
package org.apache.spark.sql.execution.datasources.parquet;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.parquet.Preconditions;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.column.values.bitpacking.BytePacker;
import org.apache.parquet.column.values.bitpacking.Packer;
import org.apache.parquet.io.ParquetDecodingException;
+import org.apache.parquet.io.api.Binary;
+
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
/**
@@ -35,7 +39,8 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
* - Definition/Repetition levels
* - Dictionary ids.
*/
-public final class VectorizedRleValuesReader extends ValuesReader {
+public final class VectorizedRleValuesReader extends ValuesReader
+ implements VectorizedValuesReader {
// Current decoding mode. The encoded data contains groups of either run length encoded data
// (RLE) or bit packed data. Each group contains a header that indicates which group it is and
// the number of values in the group.
@@ -121,6 +126,7 @@ public final class VectorizedRleValuesReader extends ValuesReader {
return readInteger();
}
+
@Override
public int readInteger() {
if (this.currentCount == 0) { this.readNextGroup(); }
@@ -138,7 +144,9 @@ public final class VectorizedRleValuesReader extends ValuesReader {
/**
* Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader
* reads the definition levels and then will read from `data` for the non-null values.
- * If the value is null, c will be populated with `nullValue`.
+ * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only
+ * necessary for readIntegers because we also use it to decode dictionaryIds and want to make
+ * sure it always has a value in range.
*
* This is a batched version of this logic:
* if (this.readInt() == level) {
@@ -180,6 +188,154 @@ public final class VectorizedRleValuesReader extends ValuesReader {
}
}
+ // TODO: can this code duplication be removed without a perf penalty?
+ public void readBytes(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readBytes(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putByte(rowId + i, data.readByte());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readLongs(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readLongs(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putLong(rowId + i, data.readLong());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readBinarys(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ c.putNotNulls(rowId, n);
+ data.readBinary(n, c, rowId);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putNotNull(rowId + i);
+ data.readBinary(1, c, rowId);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+
+ // The RLE reader implements the vectorized decoding interface when used to decode dictionary
+ // IDs. This is different from the above APIs, which decode definition levels along with values.
+ // Since this is only used to decode dictionary IDs, only decoding integers is supported.
+ @Override
+ public void readIntegers(int total, ColumnVector c, int rowId) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ c.putInts(rowId, n, currentValue);
+ break;
+ case PACKED:
+ c.putInts(rowId, n, currentBuffer, currentBufferIdx);
+ currentBufferIdx += n;
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ @Override
+ public byte readByte() {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readBytes(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readLongs(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readBinary(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void skip(int n) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+
/**
* Reads the next varint encoded int.
*/
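For readers unfamiliar with the hybrid encoding this class handles, here is a hedged, standalone sketch of the group headers it switches on (the header values are arbitrary examples and the class is not part of the patch): each group starts with a varint header whose low bit selects the mode, and the remaining bits give either the RLE run length or the number of 8-value bit-packed words.

public class RleHybridHeaderSketch {
  // Decode one group header of the Parquet RLE/bit-packing hybrid encoding.
  static void describe(int header) {
    if ((header & 1) == 0) {
      // RLE group: a single value repeated (header >>> 1) times.
      System.out.println("RLE group of " + (header >>> 1) + " repeated values");
    } else {
      // Bit-packed group: (header >>> 1) groups of 8 bit-packed values.
      System.out.println("PACKED group of " + ((header >>> 1) * 8) + " bit-packed values");
    }
  }

  public static void main(String[] args) {
    describe(20); // 20 = (10 << 1) | 0 -> RLE run of 10 values
    describe(7);  //  7 = (3 << 1) | 1  -> 3 * 8 = 24 bit-packed values
  }
}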
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index 49a9ed83d5..b6ec7311c5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -24,12 +24,17 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
* TODO: merge this into parquet-mr.
*/
public interface VectorizedValuesReader {
+ byte readByte();
int readInteger();
+ long readLong();
/*
* Reads `total` values into `c` start at `c[rowId]`
*/
+ void readBytes(int total, ColumnVector c, int rowId);
void readIntegers(int total, ColumnVector c, int rowId);
+ void readLongs(int total, ColumnVector c, int rowId);
+ void readBinary(int total, ColumnVector c, int rowId);
// TODO: add all the other parquet types.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a5bc506a65..0514252a8e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -763,7 +763,12 @@ public abstract class ColumnVector {
/**
* Returns the elements appended.
*/
- public int getElementsAppended() { return elementsAppended; }
+ public final int getElementsAppended() { return elementsAppended; }
+
+ /**
+ * Returns true if this column is an array.
+ */
+ public final boolean isArray() { return resultArray != null; }
/**
* Maximum number of rows that can be stored in this column.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
new file mode 100644
index 0000000000..cef6b79a09
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
+import org.apache.spark.sql.test.SharedSQLContext
+
+// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
+// writer abstractions. Revisit.
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
+ import testImplicits._
+
+ val ROW = ((1).toByte, 2, 3L, "abc")
+ val NULL_ROW = (
+ null.asInstanceOf[java.lang.Byte],
+ null.asInstanceOf[Integer],
+ null.asInstanceOf[java.lang.Long],
+ null.asInstanceOf[String])
+
+ test("All Types Dictionary") {
+ (1 :: 1000 :: Nil).foreach { n => {
+ withTempPath { dir =>
+ List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
+ val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+ val reader = new UnsafeRowParquetRecordReader
+ reader.initialize(file.asInstanceOf[String], null)
+ val batch = reader.resultBatch()
+ assert(reader.nextBatch())
+ assert(batch.numRows() == n)
+ var i = 0
+ while (i < n) {
+ assert(batch.column(0).getByte(i) == 1)
+ assert(batch.column(1).getInt(i) == 2)
+ assert(batch.column(2).getLong(i) == 3)
+ assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc")
+ i += 1
+ }
+ reader.close()
+ }
+ }}
+ }
+
+ test("All Types Null") {
+ (1 :: 100 :: Nil).foreach { n => {
+ withTempPath { dir =>
+ val data = List.fill(n)(NULL_ROW).toDF
+ data.repartition(1).write.parquet(dir.getCanonicalPath)
+ val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+ val reader = new UnsafeRowParquetRecordReader
+ reader.initialize(file.asInstanceOf[String], null)
+ val batch = reader.resultBatch()
+ assert(reader.nextBatch())
+ assert(batch.numRows() == n)
+ var i = 0
+ while (i < n) {
+ assert(batch.column(0).getIsNull(i))
+ assert(batch.column(1).getIsNull(i))
+ assert(batch.column(2).getIsNull(i))
+ assert(batch.column(3).getIsNull(i))
+ i += 1
+ }
+ reader.close()
+ }}
+ }
+ }
+}