3 files changed, 146 insertions, 5 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 09c001baae..c462ab1a13 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,9 +16,7 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.NoSuchElementException;
+import java.util.*;
 
 import org.apache.commons.lang.NotImplementedException;
 
@@ -58,6 +56,9 @@ public final class ColumnarBatch {
   // True if the row is filtered.
   private final boolean[] filteredRows;
 
+  // Column indices that cannot have null values.
+  private final Set<Integer> nullFilteredColumns;
+
   // Total number of rows that have been filtered.
   private int numRowsFiltered = 0;
 
@@ -284,11 +285,23 @@ public final class ColumnarBatch {
   }
 
   /**
-   * Sets the number of rows that are valid.
+   * Sets the number of rows that are valid. Additionally, marks all rows as "filtered" if one or
+   * more of their attributes are part of a non-nullable column.
    */
   public void setNumRows(int numRows) {
     assert(numRows <= this.capacity);
     this.numRows = numRows;
+
+    for (int ordinal : nullFilteredColumns) {
+      if (columns[ordinal].numNulls != 0) {
+        for (int rowId = 0; rowId < numRows; rowId++) {
+          if (!filteredRows[rowId] && columns[ordinal].getIsNull(rowId)) {
+            filteredRows[rowId] = true;
+            ++numRowsFiltered;
+          }
+        }
+      }
+    }
   }
 
   /**
@@ -345,15 +358,24 @@ public final class ColumnarBatch {
    * in this batch will not include this row.
    */
   public final void markFiltered(int rowId) {
-    assert(filteredRows[rowId] == false);
+    assert(!filteredRows[rowId]);
     filteredRows[rowId] = true;
     ++numRowsFiltered;
   }
 
+  /**
+   * Marks a given column as non-nullable. Any row that has a NULL value for the corresponding
+   * attribute is filtered out.
+   */
+  public final void filterNullsInColumn(int ordinal) {
+    nullFilteredColumns.add(ordinal);
+  }
+
   private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) {
     this.schema = schema;
     this.capacity = maxRows;
     this.columns = new ColumnVector[schema.size()];
+    this.nullFilteredColumns = new HashSet<>();
     this.filteredRows = new boolean[maxRows];
 
     for (int i = 0; i < schema.fields().length; ++i) {
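The mechanism above is easiest to see in isolation: filterNullsInColumn only records the ordinal, and setNumRows does the actual work once the batch size is known, making one pass per null-filtered column and flipping each offending row's filtered bit. Below is a minimal, self-contained Scala sketch of that bookkeeping; SimpleBatch, setNull, and liveRowIds are hypothetical stand-ins for the real ColumnarBatch/ColumnVector API, not part of the patch.

import scala.collection.mutable

// Hypothetical stand-in for ColumnarBatch: two columns, per-row null flags.
class SimpleBatch(capacity: Int) {
  private val nulls = Array.ofDim[Boolean](2, capacity) // nulls(ordinal)(rowId)
  private val filteredRows = new Array[Boolean](capacity)
  private val nullFilteredColumns = mutable.Set.empty[Int]
  private var numRows = 0

  def setNull(ordinal: Int, rowId: Int): Unit = nulls(ordinal)(rowId) = true
  def filterNullsInColumn(ordinal: Int): Unit = nullFilteredColumns += ordinal

  // Mirrors the patched setNumRows: one pass per null-filtered column.
  // (The real code first checks the column's numNulls counter and skips
  // the scan entirely when the column contains no nulls.)
  def setNumRows(n: Int): Unit = {
    require(n <= capacity)
    numRows = n
    for (ordinal <- nullFilteredColumns; rowId <- 0 until numRows) {
      if (!filteredRows(rowId) && nulls(ordinal)(rowId)) filteredRows(rowId) = true
    }
  }

  // Row ids that survive filtering, analogous to what rowIterator() yields.
  def liveRowIds: Seq[Int] = (0 until numRows).filterNot(filteredRows(_))
}

object SimpleBatchDemo extends App {
  val batch = new SimpleBatch(4)
  batch.setNull(ordinal = 1, rowId = 2) // row 2 is null in column 1
  batch.filterNullsInColumn(1)
  batch.setNumRows(4)
  println(batch.liveRowIds) // Vector(0, 1, 3)
}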
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index 38c3618a82..15bf00e6f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -299,10 +299,100 @@ object ParquetReadBenchmark {
     }
   }
 
+  def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = {
+    withTempPath { dir =>
+      withTempTable("t1", "tempTable") {
+        sqlContext.range(values).registerTempTable("t1")
+        sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " +
+          s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1")
+          .write.parquet(dir.getCanonicalPath)
+        sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+
+        val benchmark = new Benchmark("String with Nulls Scan", values)
+
+        benchmark.addCase("SQL Parquet Vectorized") { iter =>
+          sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " +
+            "not NULL and c2 is not NULL").collect()
+        }
+
+        val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
+        benchmark.addCase("PR Vectorized") { num =>
+          var sum = 0
+          files.map(_.asInstanceOf[String]).foreach { p =>
+            val reader = new UnsafeRowParquetRecordReader
+            try {
+              reader.initialize(p, ("c1" :: "c2" :: Nil).asJava)
+              val batch = reader.resultBatch()
+              while (reader.nextBatch()) {
+                val rowIterator = batch.rowIterator()
+                while (rowIterator.hasNext) {
+                  val row = rowIterator.next()
+                  val value = row.getUTF8String(0)
+                  if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes()
+                }
+              }
+            } finally {
+              reader.close()
+            }
+          }
+        }
+
+        benchmark.addCase("PR Vectorized (Null Filtering)") { num =>
+          var sum = 0L
+          files.map(_.asInstanceOf[String]).foreach { p =>
+            val reader = new UnsafeRowParquetRecordReader
+            try {
+              reader.initialize(p, ("c1" :: "c2" :: Nil).asJava)
+              val batch = reader.resultBatch()
+              batch.filterNullsInColumn(0)
+              batch.filterNullsInColumn(1)
+              while (reader.nextBatch()) {
+                val rowIterator = batch.rowIterator()
+                while (rowIterator.hasNext) {
+                  sum += rowIterator.next().getUTF8String(0).numBytes()
+                }
+              }
+            } finally {
+              reader.close()
+            }
+          }
+        }
+
+        /*
+        Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+        String with Nulls Scan (0%):        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+        -------------------------------------------------------------------------------------------
+        SQL Parquet Vectorized                    1229 / 1648          8.5         117.2       1.0X
+        PR Vectorized                              833 /  846         12.6          79.4       1.5X
+        PR Vectorized (Null Filtering)             732 /  782         14.3          69.8       1.7X
+
+        Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+        String with Nulls Scan (50%):       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+        -------------------------------------------------------------------------------------------
+        SQL Parquet Vectorized                     995 / 1053         10.5          94.9       1.0X
+        PR Vectorized                              732 /  772         14.3          69.8       1.4X
+        PR Vectorized (Null Filtering)             725 /  790         14.5          69.1       1.4X
+
+        Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+        String with Nulls Scan (95%):       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+        -------------------------------------------------------------------------------------------
+        SQL Parquet Vectorized                     326 /  333         32.2          31.1       1.0X
+        PR Vectorized                              190 /  200         55.1          18.2       1.7X
+        PR Vectorized (Null Filtering)             168 /  172         62.2          16.1       1.9X
+        */
+
+        benchmark.run()
+      }
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     intScanBenchmark(1024 * 1024 * 15)
     intStringScanBenchmark(1024 * 1024 * 10)
     stringDictionaryScanBenchmark(1024 * 1024 * 10)
     partitionTableScanBenchmark(1024 * 1024 * 15)
+    for (fractionOfNulls <- List(0.0, 0.50, 0.95)) {
+      stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls)
+    }
   }
 }
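The reason the "PR Vectorized (Null Filtering)" case can call getUTF8String(0) without the per-row isNullAt guards of the plain vectorized case is that rows marked filtered are never yielded by the batch's row iterator. That skip logic predates this patch and does not appear in the diff; a rough Scala sketch of the idea, with hypothetical names:

// Sketch of an iterator that skips rows whose filtered flag is set, as
// batch.rowIterator() is assumed to do. Yields surviving row ids only.
class SkippingIterator(numRows: Int, filtered: Array[Boolean]) extends Iterator[Int] {
  private var rowId = 0
  private def advance(): Unit = while (rowId < numRows && filtered(rowId)) rowId += 1
  advance()

  override def hasNext: Boolean = rowId < numRows
  override def next(): Int = { val r = rowId; rowId += 1; advance(); r }
}

// With rows 1 and 2 filtered, only rows 0, 3, and 4 ever reach the caller,
// so the consumer loop needs no null checks of its own:
// new SkippingIterator(5, Array(false, true, true, false, false)).toList == List(0, 3, 4)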
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index ed97f59ea1..fa2c74431a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -727,4 +727,33 @@ class ColumnarBatchSuite extends SparkFunSuite {
   test("Random nested schema") {
     testRandomRows(false, 30)
   }
+
+  test("null filtered columns") {
+    val NUM_ROWS = 10
+    val schema = new StructType()
+      .add("key", IntegerType, nullable = false)
+      .add("value", StringType, nullable = true)
+    for (numNulls <- List(0, NUM_ROWS / 2, NUM_ROWS)) {
+      val rows = mutable.ArrayBuffer.empty[Row]
+      for (i <- 0 until NUM_ROWS) {
+        val row = if (i < numNulls) Row.fromSeq(Seq(i, null)) else Row.fromSeq(Seq(i, i.toString))
+        rows += row
+      }
+      (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
+        val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava)
+        batch.filterNullsInColumn(1)
+        batch.setNumRows(NUM_ROWS)
+        assert(batch.numRows() == NUM_ROWS)
+        val it = batch.rowIterator()
+        // The first numNulls rows have a null in column 1 and should be filtered out.
+        var k = numNulls
+        while (it.hasNext) {
+          assert(it.next().getInt(0) == k)
+          k += 1
+        }
+        assert(k == NUM_ROWS)
+        batch.close()
+      }}
+    }
+  }
 }
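The test wires up filterNullsInColumn by hand; a real reader would presumably derive the ordinals from the schema's nullability flags. A small illustrative helper along those lines (filterNonNullableColumns and NullFilterUtil are not part of this patch):

import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types.StructType

object NullFilterUtil {
  // Illustrative only: mark every non-nullable field of the schema as
  // null-filtered before the first batch is populated.
  def filterNonNullableColumns(batch: ColumnarBatch, schema: StructType): Unit = {
    schema.fields.zipWithIndex.foreach { case (field, ordinal) =>
      if (!field.nullable) batch.filterNullsInColumn(ordinal)
    }
  }
}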