From 62771353767b5eecf2ec6c732cab07369d784df5 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Wed, 17 Dec 2014 12:48:04 -0800
Subject: [SPARK-4493][SQL] Don't pushdown Eq, NotEq, Lt, LtEq, Gt and GtEq
 predicates with nulls for Parquet

Predicates like `a = NULL` and `a < NULL` can't be pushed down since Parquet's
`Lt`, `LtEq`, `Gt`, and `GtEq` don't accept null values. Note that `Eq` and
`NotEq` can only be used with `null` to represent predicates like `a IS NULL`
and `a IS NOT NULL`.

However, this issue normally doesn't cause an NPE, because any value compared
to `NULL` results in `NULL`, and Spark SQL automatically optimizes out such
`NULL` predicates in the `SimplifyFilters` rule. Only testing code that
intentionally disables the optimizer may trigger this issue. (That's why this
issue is not marked as a blocker, and I do **NOT** think we need to backport
this to branch-1.1.)

This PR restricts `Lt`, `LtEq`, `Gt`, and `GtEq` to non-null values only, and
only uses `Eq` with a null value to push down `IsNull` and `IsNotNull`. It also
adds support for the Parquet `NotEq` filter, both for completeness and for a
(tiny) performance gain; `NotEq` is also used to push down `IsNotNull`.

[Review on Reviewable](https://reviewable.io/reviews/apache/spark/3367)

Author: Cheng Lian

Closes #3367 from liancheng/filters-with-null and squashes the following commits:

cc41281 [Cheng Lian] Fixes several styling issues
de7de28 [Cheng Lian] Adds stricter rules for Parquet filters with null
---
 .../spark/sql/catalyst/expressions/literals.scala |   9 ++
 .../apache/spark/sql/parquet/ParquetFilters.scala |  68 ++++++++---
 .../spark/sql/parquet/ParquetQuerySuite.scala     | 129 +++++++++++++++++++--
 3 files changed, 183 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 93c1932515..94e1d37c1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -41,6 +41,15 @@ object Literal {
   }
 }
 
+/**
+ * An extractor that matches non-null literal values.
+ */
+object NonNullLiteral {
+  def unapply(literal: Literal): Option[(Any, DataType)] = {
+    Option(literal.value).map(_ => (literal.value, literal.dataType))
+  }
+}
+
 /**
  * Extractor for retrieving Int literals.
  */
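For readers unfamiliar with Scala extractors, here is a minimal, self-contained sketch of the mechanism `NonNullLiteral` relies on. The `DataType` and `Literal` stand-ins below are simplified for illustration and are not Spark's actual Catalyst classes:

```scala
// Simplified stand-ins, for illustration only.
sealed trait DataType
case object IntegerType extends DataType
case class Literal(value: Any, dataType: DataType)

object NonNullLiteral {
  // Option(null) is None, so the pattern fails to match null-valued literals.
  def unapply(literal: Literal): Option[(Any, DataType)] =
    Option(literal.value).map(_ => (literal.value, literal.dataType))
}

def describe(l: Literal): String = l match {
  case NonNullLiteral(v, t) => s"non-null literal $v of type $t"
  case _ => "null literal, ineligible for pushdown"
}

// describe(Literal(10, IntegerType))   // "non-null literal 10 of type IntegerType"
// describe(Literal(null, IntegerType)) // "null literal, ineligible for pushdown"
```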
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index 6fb5f49b13..56e7d11b2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -50,12 +50,37 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
     case DoubleType =>
       (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+
+    // Binary.fromString and Binary.fromByteArray don't accept null values
     case StringType =>
-      (n: String, v: Any) =>
-        FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
     case BinaryType =>
-      (n: String, v: Any) =>
-        FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
+  }
+
+  val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
+    case BooleanType =>
+      (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
+    case IntegerType =>
+      (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
+    case LongType =>
+      (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
+    case FloatType =>
+      (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+    case DoubleType =>
+      (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+    case StringType =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+    case BinaryType =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
   }
 
   val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
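Two details in the hunk above are worth spelling out. First, the `Option(v).map(...).orNull` guard exists because `Binary.fromString(null)` and `Binary.fromByteArray(null)` would throw an NPE, while passing `null` straight through to `FilterApi.eq`/`FilterApi.notEq` is exactly how the filter2 API expresses `IS NULL`/`IS NOT NULL`. Second, `makeEq` and friends are `PartialFunction`s so that callers can `lift` them and get `None` back for data types that have no Parquet filter encoding. A self-contained sketch of the `lift` mechanism (the names are illustrative, not from the patch):

```scala
object LiftDemo extends App {
  // Defined only for even numbers, analogous to makeEq being defined
  // only for data types that Parquet can filter on.
  val half: PartialFunction[Int, Int] = { case n if n % 2 == 0 => n / 2 }

  // lift turns the partial function into a total one that returns Option.
  val halfOpt: Int => Option[Int] = half.lift

  println(halfOpt(4)) // Some(2): supported input, a filter would be built
  println(halfOpt(3)) // None: unsupported input, pushdown is silently skipped
}
```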
@@ -126,30 +151,45 @@
         FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
   }
 
+  // NOTE:
+  //
+  // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+  // which can be cast to `false` implicitly. Please refer to the `eval` method of these
+  // operators and the `SimplifyFilters` rule for details.
   predicate match {
-    case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
+    case IsNull(NamedExpression(name, dataType)) =>
+      makeEq.lift(dataType).map(_(name, null))
+    case IsNotNull(NamedExpression(name, dataType)) =>
+      makeNotEq.lift(dataType).map(_(name, null))
+
+    case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
       makeEq.lift(dataType).map(_(name, value))
-    case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
+    case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
       makeEq.lift(dataType).map(_(name, value))
-    case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
+
+    case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
+      makeNotEq.lift(dataType).map(_(name, value))
+    case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
+      makeNotEq.lift(dataType).map(_(name, value))
+
+    case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
       makeLt.lift(dataType).map(_(name, value))
-    case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
+    case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
       makeGt.lift(dataType).map(_(name, value))
-    case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+    case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
       makeLtEq.lift(dataType).map(_(name, value))
-    case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+    case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
       makeGtEq.lift(dataType).map(_(name, value))
-    case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
+
+    case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
       makeGt.lift(dataType).map(_(name, value))
-    case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
+    case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
       makeLt.lift(dataType).map(_(name, value))
-    case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+    case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
       makeGtEq.lift(dataType).map(_(name, value))
-    case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+    case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
       makeLtEq.lift(dataType).map(_(name, value))
 
     case And(lhs, rhs) =>
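Putting the pieces together: with the match above, `createFilter` returns `Some(...)` only for null-safe predicates. A rough, REPL-style sketch of the intended behavior, written against the Catalyst expression DSL that the test suite below uses (the comments indicate expected results, not verbatim output):

```scala
ParquetFilters.createFilter('a.int.isNull)     // Some(...): pushed down as eq(a, null)
ParquetFilters.createFilter('a.int === 10)     // Some(...): pushed down as eq(a, 10)
ParquetFilters.createFilter(!('a.int === 10))  // Some(...): pushed down as notEq(a, 10)
ParquetFilters.createFilter('a.int < null)     // None: Lt can't take null, not pushed down
```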
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 7ee4f3c1e9..0e5635d3e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.parquet
 
-import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.mapreduce.Job
 import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import parquet.filter2.predicate.{FilterPredicate, Operators}
 import parquet.hadoop.ParquetFileWriter
 import parquet.hadoop.util.ContextUtil
+import parquet.io.api.Binary
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
@@ -84,7 +85,8 @@ case class NumericData(i: Int, d: Double)
 class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   TestData // Load test data tables.
 
-  var testRDD: SchemaRDD = null
+  private var testRDD: SchemaRDD = null
+  private val originalParquetFilterPushdownEnabled = TestSQLContext.parquetFilterPushDown
 
   override def beforeAll() {
     ParquetTestData.writeFile()
@@ -109,13 +111,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     Utils.deleteRecursively(ParquetTestData.testNestedDir3)
     Utils.deleteRecursively(ParquetTestData.testNestedDir4)
     // here we should also unregister the table??
+
+    setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString)
   }
 
   test("Read/Write All Types") {
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
     val range = (0 to 255)
-    val data = sparkContext.parallelize(range)
-      .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+    val data = sparkContext.parallelize(range).map { x =>
+      parquet.AllDataTypes(
+        s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)
+    }
 
     data.saveAsParquetFile(tempDir)
 
@@ -260,14 +266,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   test("Read/Write All Types with non-primitive type") {
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
     val range = (0 to 255)
-    val data = sparkContext.parallelize(range)
-      .map(x => AllDataTypesWithNonPrimitiveType(
+    val data = sparkContext.parallelize(range).map { x =>
+      parquet.AllDataTypesWithNonPrimitiveType(
         s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
         (0 until x),
         (0 until x).map(Option(_).filter(_ % 3 == 0)),
         (0 until x).map(i => i -> i.toLong).toMap,
         (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
-        Data((0 until x), Nested(x, s"$x"))))
+        parquet.Data((0 until x), parquet.Nested(x, s"$x")))
+    }
 
     data.saveAsParquetFile(tempDir)
 
     checkAnswer(
@@ -420,7 +427,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("save and load case class RDD with nulls as parquet") {
-    val data = NullReflectData(null, null, null, null, null)
+    val data = parquet.NullReflectData(null, null, null, null, null)
     val rdd = sparkContext.parallelize(data :: Nil)
 
     val file = getTempFilePath("parquet")
@@ -435,7 +442,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("save and load case class RDD with Nones as parquet") {
-    val data = OptionalReflectData(None, None, None, None, None)
+    val data = parquet.OptionalReflectData(None, None, None, None, None)
     val rdd = sparkContext.parallelize(data :: Nil)
 
     val file = getTempFilePath("parquet")
@@ -938,4 +945,108 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
       checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
     }
   }
+
+  def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = {
+    val filter = ParquetFilters.createFilter(predicate)
+    assert(filter.isDefined)
+    assert(filter.get.getClass == filterClass)
+  }
+
+  test("Pushdown IsNull predicate") {
+    checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]])
+    checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]])
+    checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]])
+    checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]])
+    checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]])
+    checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]])
+  }
+
predicate") { + checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]]) + checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]]) + checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]]) + checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]]) + checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]]) + checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]]) + } + + test("Pushdown EqualTo predicate") { + checkFilter('a.int === 0, classOf[Operators.Eq[Integer]]) + checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]]) + checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]]) + checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]]) + checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]]) + checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]]) + } + + test("Pushdown Not(EqualTo) predicate") { + checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]]) + checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]]) + checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]]) + checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]]) + checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]]) + checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]]) + } + + test("Pushdown LessThan predicate") { + checkFilter('a.int < 0, classOf[Operators.Lt[Integer]]) + checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]]) + checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]]) + checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]]) + checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]]) + checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]]) + } + + test("Pushdown LessThanOrEqual predicate") { + checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]]) + checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]]) + checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]]) + checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]]) + checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]]) + checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]]) + } + + test("Pushdown GreaterThan predicate") { + checkFilter('a.int > 0, classOf[Operators.Gt[Integer]]) + checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]]) + checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]]) + checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]]) + checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]]) + checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]]) + } + + test("Pushdown GreaterThanOrEqual predicate") { + checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]]) + checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]]) + checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]]) + checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]]) + checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]]) + checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]]) + } + + test("Comparison with null should not be pushed down") { + val predicates = Seq( + 'a.int === null, + !('a.int === null), + + Literal(null) === 'a.int, + 
+      !(Literal(null) === 'a.int),
+
+      'a.int < null,
+      'a.int <= null,
+      'a.int > null,
+      'a.int >= null,
+
+      Literal(null) < 'a.int,
+      Literal(null) <= 'a.int,
+      Literal(null) > 'a.int,
+      Literal(null) >= 'a.int
+    )
+
+    predicates.foreach { p =>
+      assert(
+        ParquetFilters.createFilter(p).isEmpty,
+        "Comparison predicate with null shouldn't be pushed down")
+    }
+  }
 }
--
cgit v1.2.3