diff options
8 files changed, 944 insertions, 43 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f39..4d176332b6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? + */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb6..33769363a0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c94686780..d26b1b187c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -41,12 +50,23 @@ public class BufferHolder { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8..e227c0dec9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public class UnsafeRowWriter { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public class UnsafeRowWriter { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public class UnsafeRowWriter { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public class UnsafeRowWriter { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public class UnsafeRowWriter { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public class UnsafeRowWriter { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public class UnsafeRowWriter { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 0000000000..2ed30c1f5a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. + * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Void, T> { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport<T> readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List<BlockMetaData> blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set<Long> offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map<String, String> fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class<? extends ReadSupport<T>>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? + */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static <K, V> Map<K, Set<V>> toSetMultiMap(Map<K, V> map) { + Map<K, Set<V>> setMultiMap = new HashMap<>(); + for (Map.Entry<K, V> entry : map.entrySet()) { + Set<V> set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class<?> getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static <T> ReadSupport<T> getReadSupportInstance( + Class<? extends ReadSupport<T>> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 0000000000..8a92e489cc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<UnsafeRow> { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. + */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. + */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. + rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. + */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? + page.accept(new DataPage.Visitor<Void>() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. Read " + + rowsReturned + " out of " + totalRowCount); + } + List<ColumnDescriptor> columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3..94298fae2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. + val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77a..c8028a5ef5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => |