aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNong Li <nong@databricks.com>2015-11-18 18:38:45 -0800
committerReynold Xin <rxin@databricks.com>2015-11-18 18:38:45 -0800
commit6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d (patch)
treecf1c2b5184a996d4e931d1837dd7899199c2ba72
parente61367b9f9bfc8e123369d55d7ca5925568b98a7 (diff)
downloadspark-6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d.tar.gz
spark-6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d.tar.bz2
spark-6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d.zip
[SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas.
This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li <nong@databricks.com> Closes #9774 from nongli/parquet.
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala41
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java9
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java32
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java20
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java240
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java593
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala48
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala4
8 files changed, 944 insertions, 43 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
index 264dae7f39..4d176332b6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
@@ -20,8 +20,6 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
-import scala.reflect.ClassTag
-
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager}
import org.apache.spark.{Partition => SparkPartition, _}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils}
+
+import scala.reflect.ClassTag
private[spark] class SqlNewHadoopPartition(
@@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
@transient protected val jobId = new JobID(jobTrackerId, id)
+ // If true, enable using the custom RecordReader for parquet. This only works for
+ // a subset of the types (no complex types).
+ protected val enableUnsafeRowParquetReader: Boolean =
+ sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true)
+
override def getPartitions: Array[SparkPartition] = {
val conf = getConf(isDriverSide = true)
val inputFormat = inputFormatClass.newInstance
@@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
configurable.setConf(conf)
case _ =>
}
- private[this] var reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+ private[this] var reader: RecordReader[Void, V] = null
+
+ /**
+ * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this
+ * fails (for example, unsupported schema), try with the normal reader.
+ * TODO: plumb this through a different way?
+ */
+ if (enableUnsafeRowParquetReader &&
+ format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
+ // TODO: move this class to sql.execution and remove this.
+ reader = Utils.classForName(
+ "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader")
+ .newInstance().asInstanceOf[RecordReader[Void, V]]
+ try {
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+ } catch {
+ case e: Exception => reader = null
+ }
+ }
+
+ if (reader == null) {
+ reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+ }
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5ba14ebdb6..33769363a0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -178,6 +178,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
}
+ /**
+ * Updates this UnsafeRow preserving the number of fields.
+ * @param buf byte array to point to
+ * @param sizeInBytes the number of bytes valid in the byte array
+ */
+ public void pointTo(byte[] buf, int sizeInBytes) {
+ pointTo(buf, numFields, sizeInBytes);
+ }
+
@Override
public void setNullAt(int i) {
assertIndexIsValid(i);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 9c94686780..d26b1b187c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -17,19 +17,28 @@
package org.apache.spark.sql.catalyst.expressions.codegen;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.Platform;
/**
- * A helper class to manage the row buffer used in `GenerateUnsafeProjection`.
- *
- * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables
- * public for ease of use.
+ * A helper class to manage the row buffer when construct unsafe rows.
*/
public class BufferHolder {
- public byte[] buffer = new byte[64];
+ public byte[] buffer;
public int cursor = Platform.BYTE_ARRAY_OFFSET;
- public void grow(int neededSize) {
+ public BufferHolder() {
+ this(64);
+ }
+
+ public BufferHolder(int size) {
+ buffer = new byte[size];
+ }
+
+ /**
+ * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer.
+ */
+ public void grow(int neededSize, UnsafeRow row) {
final int length = totalSize() + neededSize;
if (buffer.length < length) {
// This will not happen frequently, because the buffer is re-used.
@@ -41,12 +50,23 @@ public class BufferHolder {
Platform.BYTE_ARRAY_OFFSET,
totalSize());
buffer = tmp;
+ if (row != null) {
+ row.pointTo(buffer, length * 2);
+ }
}
}
+ public void grow(int neededSize) {
+ grow(neededSize, null);
+ }
+
public void reset() {
cursor = Platform.BYTE_ARRAY_OFFSET;
}
+ public void resetTo(int offset) {
+ assert(offset <= buffer.length);
+ cursor = Platform.BYTE_ARRAY_OFFSET + offset;
+ }
public int totalSize() {
return cursor - Platform.BYTE_ARRAY_OFFSET;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 048b7749d8..e227c0dec9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -35,6 +35,7 @@ public class UnsafeRowWriter {
// The offset of the global buffer where we start to write this row.
private int startingOffset;
private int nullBitsSize;
+ private UnsafeRow row;
public void initialize(BufferHolder holder, int numFields) {
this.holder = holder;
@@ -43,7 +44,7 @@ public class UnsafeRowWriter {
// grow the global buffer to make sure it has enough space to write fixed-length data.
final int fixedSize = nullBitsSize + 8 * numFields;
- holder.grow(fixedSize);
+ holder.grow(fixedSize, row);
holder.cursor += fixedSize;
// zero-out the null bits region
@@ -52,12 +53,19 @@ public class UnsafeRowWriter {
}
}
+ public void initialize(UnsafeRow row, BufferHolder holder, int numFields) {
+ initialize(holder, numFields);
+ this.row = row;
+ }
+
private void zeroOutPaddingBytes(int numBytes) {
if ((numBytes & 0x07) > 0) {
Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
}
}
+ public BufferHolder holder() { return holder; }
+
public boolean isNullAt(int ordinal) {
return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
}
@@ -90,7 +98,7 @@ public class UnsafeRowWriter {
if (remainder > 0) {
final int paddingBytes = 8 - remainder;
- holder.grow(paddingBytes);
+ holder.grow(paddingBytes, row);
for (int i = 0; i < paddingBytes; i++) {
Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
@@ -153,7 +161,7 @@ public class UnsafeRowWriter {
}
} else {
// grow the global buffer before writing data.
- holder.grow(16);
+ holder.grow(16, row);
// zero-out the bytes
Platform.putLong(holder.buffer, holder.cursor, 0L);
@@ -185,7 +193,7 @@ public class UnsafeRowWriter {
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
// grow the global buffer before writing data.
- holder.grow(roundedSize);
+ holder.grow(roundedSize, row);
zeroOutPaddingBytes(numBytes);
@@ -206,7 +214,7 @@ public class UnsafeRowWriter {
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
// grow the global buffer before writing data.
- holder.grow(roundedSize);
+ holder.grow(roundedSize, row);
zeroOutPaddingBytes(numBytes);
@@ -222,7 +230,7 @@ public class UnsafeRowWriter {
public void write(int ordinal, CalendarInterval input) {
// grow the global buffer before writing data.
- holder.grow(16);
+ holder.grow(16, row);
// Write the months and microseconds fields of Interval to the variable length portion.
Platform.putLong(holder.buffer, holder.cursor, input.months);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
new file mode 100644
index 0000000000..2ed30c1f5a
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups;
+import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER;
+import static org.apache.parquet.format.converter.ParquetMetadataConverter.range;
+import static org.apache.parquet.hadoop.ParquetFileReader.readFooter;
+import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.parquet.bytes.BytesInput;
+import org.apache.parquet.bytes.BytesUtils;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder;
+import org.apache.parquet.filter2.compat.FilterCompat;
+import org.apache.parquet.hadoop.BadConfigurationException;
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.hadoop.ParquetInputFormat;
+import org.apache.parquet.hadoop.ParquetInputSplit;
+import org.apache.parquet.hadoop.api.InitContext;
+import org.apache.parquet.hadoop.api.ReadSupport;
+import org.apache.parquet.hadoop.metadata.BlockMetaData;
+import org.apache.parquet.hadoop.metadata.ParquetMetadata;
+import org.apache.parquet.hadoop.util.ConfigurationUtil;
+import org.apache.parquet.schema.MessageType;
+
+/**
+ * Base class for custom RecordReaaders for Parquet that directly materialize to `T`.
+ * This class handles computing row groups, filtering on them, setting up the column readers,
+ * etc.
+ * This is heavily based on parquet-mr's RecordReader.
+ * TODO: move this to the parquet-mr project. There are performance benefits of doing it
+ * this way, albeit at a higher cost to implement. This base class is reusable.
+ */
+public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Void, T> {
+ protected Path file;
+ protected MessageType fileSchema;
+ protected MessageType requestedSchema;
+ protected ReadSupport<T> readSupport;
+
+ /**
+ * The total number of rows this RecordReader will eventually read. The sum of the
+ * rows of all the row groups.
+ */
+ protected long totalRowCount;
+
+ protected ParquetFileReader reader;
+
+ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
+ throws IOException, InterruptedException {
+ Configuration configuration = taskAttemptContext.getConfiguration();
+ ParquetInputSplit split = (ParquetInputSplit)inputSplit;
+ this.file = split.getPath();
+ long[] rowGroupOffsets = split.getRowGroupOffsets();
+
+ ParquetMetadata footer;
+ List<BlockMetaData> blocks;
+
+ // if task.side.metadata is set, rowGroupOffsets is null
+ if (rowGroupOffsets == null) {
+ // then we need to apply the predicate push down filter
+ footer = readFooter(configuration, file, range(split.getStart(), split.getEnd()));
+ MessageType fileSchema = footer.getFileMetaData().getSchema();
+ FilterCompat.Filter filter = getFilter(configuration);
+ blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema);
+ } else {
+ // otherwise we find the row groups that were selected on the client
+ footer = readFooter(configuration, file, NO_FILTER);
+ Set<Long> offsets = new HashSet<>();
+ for (long offset : rowGroupOffsets) {
+ offsets.add(offset);
+ }
+ blocks = new ArrayList<>();
+ for (BlockMetaData block : footer.getBlocks()) {
+ if (offsets.contains(block.getStartingPos())) {
+ blocks.add(block);
+ }
+ }
+ // verify we found them all
+ if (blocks.size() != rowGroupOffsets.length) {
+ long[] foundRowGroupOffsets = new long[footer.getBlocks().size()];
+ for (int i = 0; i < foundRowGroupOffsets.length; i++) {
+ foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos();
+ }
+ // this should never happen.
+ // provide a good error message in case there's a bug
+ throw new IllegalStateException(
+ "All the offsets listed in the split should be found in the file."
+ + " expected: " + Arrays.toString(rowGroupOffsets)
+ + " found: " + blocks
+ + " out of: " + Arrays.toString(foundRowGroupOffsets)
+ + " in range " + split.getStart() + ", " + split.getEnd());
+ }
+ }
+ MessageType fileSchema = footer.getFileMetaData().getSchema();
+ Map<String, String> fileMetadata = footer.getFileMetaData().getKeyValueMetaData();
+ this.readSupport = getReadSupportInstance(
+ (Class<? extends ReadSupport<T>>) getReadSupportClass(configuration));
+ ReadSupport.ReadContext readContext = readSupport.init(new InitContext(
+ taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema));
+ this.requestedSchema = readContext.getRequestedSchema();
+ this.fileSchema = fileSchema;
+ this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns());
+ for (BlockMetaData block : blocks) {
+ this.totalRowCount += block.getRowCount();
+ }
+ }
+
+ @Override
+ public Void getCurrentKey() throws IOException, InterruptedException {
+ return null;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (reader != null) {
+ reader.close();
+ reader = null;
+ }
+ }
+
+ /**
+ * Utility classes to abstract over different way to read ints with different encodings.
+ * TODO: remove this layer of abstraction?
+ */
+ abstract static class IntIterator {
+ abstract int nextInt() throws IOException;
+ }
+
+ protected static final class ValuesReaderIntIterator extends IntIterator {
+ ValuesReader delegate;
+
+ public ValuesReaderIntIterator(ValuesReader delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ int nextInt() throws IOException {
+ return delegate.readInteger();
+ }
+ }
+
+ protected static final class RLEIntIterator extends IntIterator {
+ RunLengthBitPackingHybridDecoder delegate;
+
+ public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ int nextInt() throws IOException {
+ return delegate.readInt();
+ }
+ }
+
+ protected static final class NullIntIterator extends IntIterator {
+ @Override
+ int nextInt() throws IOException { return 0; }
+ }
+
+ /**
+ * Creates a reader for definition and repetition levels, returning an optimized one if
+ * the levels are not needed.
+ */
+ static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes,
+ ColumnDescriptor descriptor) throws IOException {
+ try {
+ if (maxLevel == 0) return new NullIntIterator();
+ return new RLEIntIterator(
+ new RunLengthBitPackingHybridDecoder(
+ BytesUtils.getWidthFromMaxInt(maxLevel),
+ new ByteArrayInputStream(bytes.toByteArray())));
+ } catch (IOException e) {
+ throw new IOException("could not read levels in page for col " + descriptor, e);
+ }
+ }
+
+ private static <K, V> Map<K, Set<V>> toSetMultiMap(Map<K, V> map) {
+ Map<K, Set<V>> setMultiMap = new HashMap<>();
+ for (Map.Entry<K, V> entry : map.entrySet()) {
+ Set<V> set = new HashSet<>();
+ set.add(entry.getValue());
+ setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set));
+ }
+ return Collections.unmodifiableMap(setMultiMap);
+ }
+
+ private static Class<?> getReadSupportClass(Configuration configuration) {
+ return ConfigurationUtil.getClassFromConfig(configuration,
+ ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class);
+ }
+
+ /**
+ * @param readSupportClass to instantiate
+ * @return the configured read support
+ */
+ private static <T> ReadSupport<T> getReadSupportInstance(
+ Class<? extends ReadSupport<T>> readSupportClass){
+ try {
+ return readSupportClass.newInstance();
+ } catch (InstantiationException e) {
+ throw new BadConfigurationException("could not instantiate read support class", e);
+ } catch (IllegalAccessException e) {
+ throw new BadConfigurationException("could not instantiate read support class", e);
+ }
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
new file mode 100644
index 0000000000..8a92e489cc
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -0,0 +1,593 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL;
+import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
+import static org.apache.parquet.column.ValuesType.VALUES;
+
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.parquet.Preconditions;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.Dictionary;
+import org.apache.parquet.column.Encoding;
+import org.apache.parquet.column.page.DataPage;
+import org.apache.parquet.column.page.DataPageV1;
+import org.apache.parquet.column.page.DataPageV2;
+import org.apache.parquet.column.page.DictionaryPage;
+import org.apache.parquet.column.page.PageReadStore;
+import org.apache.parquet.column.page.PageReader;
+import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.OriginalType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+
+/**
+ * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs.
+ *
+ * This is somewhat based on parquet-mr's ColumnReader.
+ *
+ * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
+ * All of these can be handled efficiently and easily with codegen.
+ */
+public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<UnsafeRow> {
+ /**
+ * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this
+ * batch is used up (batchIdx == numBatched), we populated the batch.
+ */
+ private UnsafeRow[] rows = new UnsafeRow[64];
+ private int batchIdx = 0;
+ private int numBatched = 0;
+
+ /**
+ * Used to write variable length columns. Same length as `rows`.
+ */
+ private UnsafeRowWriter[] rowWriters = null;
+ /**
+ * True if the row contains variable length fields.
+ */
+ private boolean containsVarLenFields;
+
+ /**
+ * The number of bytes in the fixed length portion of the row.
+ */
+ private int fixedSizeBytes;
+
+ /**
+ * For each request column, the reader to read this column.
+ * columnsReaders[i] populated the UnsafeRow's attribute at i.
+ */
+ private ColumnReader[] columnReaders;
+
+ /**
+ * The number of rows that have been returned.
+ */
+ private long rowsReturned;
+
+ /**
+ * The number of rows that have been reading, including the current in flight row group.
+ */
+ private long totalCountLoadedSoFar = 0;
+
+ /**
+ * For each column, the annotated original type.
+ */
+ private OriginalType[] originalTypes;
+
+ /**
+ * The default size for varlen columns. The row grows as necessary to accommodate the
+ * largest column.
+ */
+ private static final int DEFAULT_VAR_LEN_SIZE = 32;
+
+ /**
+ * Implementation of RecordReader API.
+ */
+ @Override
+ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
+ throws IOException, InterruptedException {
+ super.initialize(inputSplit, taskAttemptContext);
+
+ /**
+ * Check that the requested schema is supported.
+ */
+ if (requestedSchema.getFieldCount() == 0) {
+ // TODO: what does this mean?
+ throw new IOException("Empty request schema not supported.");
+ }
+ int numVarLenFields = 0;
+ originalTypes = new OriginalType[requestedSchema.getFieldCount()];
+ for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
+ Type t = requestedSchema.getFields().get(i);
+ if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) {
+ throw new IOException("Complex types not supported.");
+ }
+ PrimitiveType primitiveType = t.asPrimitiveType();
+
+ originalTypes[i] = t.getOriginalType();
+
+ // TODO: Be extremely cautious in what is supported. Expand this.
+ if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
+ originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) {
+ throw new IOException("Unsupported type: " + t);
+ }
+ if (originalTypes[i] == OriginalType.DECIMAL &&
+ primitiveType.getDecimalMetadata().getPrecision() >
+ CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
+ throw new IOException("Decimal with high precision is not supported.");
+ }
+ if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
+ throw new IOException("Int96 not supported.");
+ }
+ ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i));
+ if (!fd.equals(requestedSchema.getColumns().get(i))) {
+ throw new IOException("Schema evolution not supported.");
+ }
+
+ if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) {
+ ++numVarLenFields;
+ }
+ }
+
+ /**
+ * Initialize rows and rowWriters. These objects are reused across all rows in the relation.
+ */
+ int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount());
+ rowByteSize += 8 * requestedSchema.getFieldCount();
+ fixedSizeBytes = rowByteSize;
+ rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE;
+ containsVarLenFields = numVarLenFields > 0;
+ rowWriters = new UnsafeRowWriter[rows.length];
+
+ for (int i = 0; i < rows.length; ++i) {
+ rows[i] = new UnsafeRow();
+ rowWriters[i] = new UnsafeRowWriter();
+ BufferHolder holder = new BufferHolder(rowByteSize);
+ rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount());
+ rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(),
+ holder.buffer.length);
+ }
+ }
+
+ @Override
+ public boolean nextKeyValue() throws IOException, InterruptedException {
+ if (batchIdx >= numBatched) {
+ if (!loadBatch()) return false;
+ }
+ ++batchIdx;
+ return true;
+ }
+
+ @Override
+ public UnsafeRow getCurrentValue() throws IOException, InterruptedException {
+ return rows[batchIdx - 1];
+ }
+
+ @Override
+ public float getProgress() throws IOException, InterruptedException {
+ return (float) rowsReturned / totalRowCount;
+ }
+
+ /**
+ * Decodes a batch of values into `rows`. This function is the hot path.
+ */
+ private boolean loadBatch() throws IOException {
+ // no more records left
+ if (rowsReturned >= totalRowCount) { return false; }
+ checkEndOfRowGroup();
+
+ int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned);
+ rowsReturned += num;
+
+ if (containsVarLenFields) {
+ for (int i = 0; i < rowWriters.length; ++i) {
+ rowWriters[i].holder().resetTo(fixedSizeBytes);
+ }
+ }
+
+ for (int i = 0; i < columnReaders.length; ++i) {
+ switch (columnReaders[i].descriptor.getType()) {
+ case BOOLEAN:
+ decodeBooleanBatch(i, num);
+ break;
+ case INT32:
+ if (originalTypes[i] == OriginalType.DECIMAL) {
+ decodeIntAsDecimalBatch(i, num);
+ } else {
+ decodeIntBatch(i, num);
+ }
+ break;
+ case INT64:
+ Preconditions.checkState(originalTypes[i] == null
+ || originalTypes[i] == OriginalType.DECIMAL,
+ "Unexpected original type: " + originalTypes[i]);
+ decodeLongBatch(i, num);
+ break;
+ case FLOAT:
+ decodeFloatBatch(i, num);
+ break;
+ case DOUBLE:
+ decodeDoubleBatch(i, num);
+ break;
+ case BINARY:
+ decodeBinaryBatch(i, num);
+ break;
+ case FIXED_LEN_BYTE_ARRAY:
+ Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL,
+ "Unexpected original type: " + originalTypes[i]);
+ decodeFixedLenArrayAsDecimalBatch(i, num);
+ break;
+ case INT96:
+ throw new IOException("Unsupported " + columnReaders[i].descriptor.getType());
+ }
+ numBatched = num;
+ batchIdx = 0;
+ }
+ return true;
+ }
+
+ private void decodeBooleanBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ rows[n].setBoolean(col, columnReaders[col].nextBoolean());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeIntBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ rows[n].setInt(col, columnReaders[col].nextInt());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeIntAsDecimalBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ // Since this is stored as an INT, it is always a compact decimal. Just set it as a long.
+ rows[n].setLong(col, columnReaders[col].nextInt());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeLongBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ rows[n].setLong(col, columnReaders[col].nextLong());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeFloatBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ rows[n].setFloat(col, columnReaders[col].nextFloat());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeDoubleBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ rows[n].setDouble(col, columnReaders[col].nextDouble());
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeBinaryBatch(int col, int num) throws IOException {
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer();
+ int len = bytes.limit() - bytes.position();
+ if (originalTypes[col] == OriginalType.UTF8) {
+ UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len);
+ rowWriters[n].write(col, str);
+ } else {
+ rowWriters[n].write(col, bytes.array(), bytes.position(), len);
+ }
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException {
+ PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
+ int precision = type.getDecimalMetadata().getPrecision();
+ int scale = type.getDecimalMetadata().getScale();
+ Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
+ "Unsupported precision.");
+
+ for (int n = 0; n < num; ++n) {
+ if (columnReaders[col].next()) {
+ Binary v = columnReaders[col].nextBinary();
+ // Constructs a `Decimal` with an unscaled `Long` value if possible.
+ long unscaled = CatalystRowConverter.binaryToUnscaledLong(v);
+ rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision);
+ } else {
+ rows[n].setNullAt(col);
+ }
+ }
+ }
+
+ /**
+ *
+ * Decoder to return values from a single column.
+ */
+ private static final class ColumnReader {
+ /**
+ * Total number of values read.
+ */
+ private long valuesRead;
+
+ /**
+ * value that indicates the end of the current page. That is,
+ * if valuesRead == endOfPageValueCount, we are at the end of the page.
+ */
+ private long endOfPageValueCount;
+
+ /**
+ * The dictionary, if this column has dictionary encoding.
+ */
+ private final Dictionary dictionary;
+
+ /**
+ * If true, the current page is dictionary encoded.
+ */
+ private boolean useDictionary;
+
+ /**
+ * Maximum definition level for this column.
+ */
+ private final int maxDefLevel;
+
+ /**
+ * Repetition/Definition/Value readers.
+ */
+ private IntIterator repetitionLevelColumn;
+ private IntIterator definitionLevelColumn;
+ private ValuesReader dataColumn;
+
+ /**
+ * Total number of values in this column (in this row group).
+ */
+ private final long totalValueCount;
+
+ /**
+ * Total values in the current page.
+ */
+ private int pageValueCount;
+
+ private final PageReader pageReader;
+ private final ColumnDescriptor descriptor;
+
+ public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader)
+ throws IOException {
+ this.descriptor = descriptor;
+ this.pageReader = pageReader;
+ this.maxDefLevel = descriptor.getMaxDefinitionLevel();
+
+ DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
+ if (dictionaryPage != null) {
+ try {
+ this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage);
+ this.useDictionary = true;
+ } catch (IOException e) {
+ throw new IOException("could not decode the dictionary for " + descriptor, e);
+ }
+ } else {
+ this.dictionary = null;
+ this.useDictionary = false;
+ }
+ this.totalValueCount = pageReader.getTotalValueCount();
+ if (totalValueCount == 0) {
+ throw new IOException("totalValueCount == 0");
+ }
+ }
+
+ /**
+ * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned.
+ */
+ public boolean nextBoolean() {
+ if (!useDictionary) {
+ return dataColumn.readBoolean();
+ } else {
+ return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ public int nextInt() {
+ if (!useDictionary) {
+ return dataColumn.readInteger();
+ } else {
+ return dictionary.decodeToInt(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ public long nextLong() {
+ if (!useDictionary) {
+ return dataColumn.readLong();
+ } else {
+ return dictionary.decodeToLong(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ public float nextFloat() {
+ if (!useDictionary) {
+ return dataColumn.readFloat();
+ } else {
+ return dictionary.decodeToFloat(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ public double nextDouble() {
+ if (!useDictionary) {
+ return dataColumn.readDouble();
+ } else {
+ return dictionary.decodeToDouble(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ public Binary nextBinary() {
+ if (!useDictionary) {
+ return dataColumn.readBytes();
+ } else {
+ return dictionary.decodeToBinary(dataColumn.readValueDictionaryId());
+ }
+ }
+
+ /**
+ * Advances to the next value. Returns true if the value is non-null.
+ */
+ private boolean next() throws IOException {
+ if (valuesRead >= endOfPageValueCount) {
+ if (valuesRead >= totalValueCount) {
+ // How do we get here? Throw end of stream exception?
+ return false;
+ }
+ readPage();
+ }
+ ++valuesRead;
+ // TODO: Don't read for flat schemas
+ //repetitionLevel = repetitionLevelColumn.nextInt();
+ return definitionLevelColumn.nextInt() == maxDefLevel;
+ }
+
+ private void readPage() throws IOException {
+ DataPage page = pageReader.readPage();
+ // TODO: Why is this a visitor?
+ page.accept(new DataPage.Visitor<Void>() {
+ @Override
+ public Void visit(DataPageV1 dataPageV1) {
+ try {
+ readPageV1(dataPageV1);
+ return null;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public Void visit(DataPageV2 dataPageV2) {
+ try {
+ readPageV2(dataPageV2);
+ return null;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ });
+ }
+
+ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount)
+ throws IOException {
+ this.pageValueCount = valueCount;
+ this.endOfPageValueCount = valuesRead + pageValueCount;
+ if (dataEncoding.usesDictionary()) {
+ if (dictionary == null) {
+ throw new IOException(
+ "could not read page in col " + descriptor +
+ " as the dictionary was missing for encoding " + dataEncoding);
+ }
+ this.dataColumn = dataEncoding.getDictionaryBasedValuesReader(
+ descriptor, VALUES, dictionary);
+ this.useDictionary = true;
+ } else {
+ this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES);
+ this.useDictionary = false;
+ }
+
+ try {
+ dataColumn.initFromPage(pageValueCount, bytes, offset);
+ } catch (IOException e) {
+ throw new IOException("could not read page in col " + descriptor, e);
+ }
+ }
+
+ private void readPageV1(DataPageV1 page) throws IOException {
+ ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
+ ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL);
+ this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader);
+ this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader);
+ try {
+ byte[] bytes = page.getBytes().toByteArray();
+ rlReader.initFromPage(pageValueCount, bytes, 0);
+ int next = rlReader.getNextOffset();
+ dlReader.initFromPage(pageValueCount, bytes, next);
+ next = dlReader.getNextOffset();
+ initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount());
+ } catch (IOException e) {
+ throw new IOException("could not read page " + page + " in col " + descriptor, e);
+ }
+ }
+
+ private void readPageV2(DataPageV2 page) throws IOException {
+ this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
+ page.getRepetitionLevels(), descriptor);
+ this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
+ page.getDefinitionLevels(), descriptor);
+ try {
+ initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0,
+ page.getValueCount());
+ } catch (IOException e) {
+ throw new IOException("could not read page " + page + " in col " + descriptor, e);
+ }
+ }
+ }
+
+ private void checkEndOfRowGroup() throws IOException {
+ if (rowsReturned != totalCountLoadedSoFar) return;
+ PageReadStore pages = reader.readNextRowGroup();
+ if (pages == null) {
+ throw new IOException("expecting more rows but reached last block. Read "
+ + rowsReturned + " out of " + totalRowCount);
+ }
+ List<ColumnDescriptor> columns = requestedSchema.getColumns();
+ columnReaders = new ColumnReader[columns.size()];
+ for (int i = 0; i < columns.size(); ++i) {
+ columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i)));
+ }
+ totalCountLoadedSoFar += pages.getRowCount();
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 1f653cd3d3..94298fae2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter(
protected def decimalFromBinary(value: Binary): Decimal = {
if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
- val unscaled = binaryToUnscaledLong(value)
+ val unscaled = CatalystRowConverter.binaryToUnscaledLong(value)
Decimal(unscaled, precision, scale)
} else {
// Otherwise, resorts to an unscaled `BigInteger` instead.
Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale)
}
}
-
- private def binaryToUnscaledLong(binary: Binary): Long = {
- // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here
- // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without
- // copying it.
- val buffer = binary.toByteBuffer
- val bytes = buffer.array()
- val start = buffer.position()
- val end = buffer.limit()
-
- var unscaled = 0L
- var i = start
-
- while (i < end) {
- unscaled = (unscaled << 8) | (bytes(i) & 0xff)
- i += 1
- }
-
- val bits = 8 * (end - start)
- unscaled = (unscaled << (64 - bits)) >> (64 - bits)
- unscaled
- }
}
private class CatalystIntDictionaryAwareDecimalConverter(
@@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter(
override def start(): Unit = elementConverter.start()
}
}
+
+private[parquet] object CatalystRowConverter {
+ def binaryToUnscaledLong(binary: Binary): Long = {
+ // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here
+ // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without
+ // copying it.
+ val buffer = binary.toByteBuffer
+ val bytes = buffer.array()
+ val start = buffer.position()
+ val end = buffer.limit()
+
+ var unscaled = 0L
+ var i = start
+
+ while (i < end) {
+ unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+ i += 1
+ }
+
+ val bits = 8 * (end - start)
+ unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+ unscaled
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 458786f77a..c8028a5ef5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
- test("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
+ // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does
+ // not do row by row filtering (and we probably don't want to push that).
+ ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>