aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNong Li <nong@databricks.com>2015-11-20 15:30:53 -0800
committerReynold Xin <rxin@databricks.com>2015-11-20 15:30:53 -0800
commit58b4e4f88a330135c4cec04a30d24ef91bc61d91 (patch)
tree398c3f81bf05f420c606670ed61b8f92b05f3970
parented47b1e660b830e2d4fac8d6df93f634b260393c (diff)
downloadspark-58b4e4f88a330135c4cec04a30d24ef91bc61d91.tar.gz
spark-58b4e4f88a330135c4cec04a30d24ef91bc61d91.tar.bz2
spark-58b4e4f88a330135c4cec04a30d24ef91bc61d91.zip
[SPARK-11787][SPARK-11883][SQL][FOLLOW-UP] Cleanup for this patch.
This mainly moves SqlNewHadoopRDD to the sql package. There is some state that is shared between core and I've left that in core. This allows some other associated minor cleanup. Author: Nong Li <nong@databricks.com> Closes #9845 from nongli/spark-11787.
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala41
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java59
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala6
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala)60
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala43
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala19
10 files changed, 175 insertions, 80 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 7db5834687..f37c95bedc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -215,8 +215,8 @@ class HadoopRDD[K, V](
// Sets the thread local variable for the file's name
split.inputSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDD.unsetInputFileName()
+ case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDDState.unsetInputFileName()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
@@ -256,7 +256,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
- SqlNewHadoopRDD.unsetInputFileName()
+ SqlNewHadoopRDDState.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
new file mode 100644
index 0000000000..3f15fff793
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * State for SqlNewHadoopRDD objects. This is split this way because of the package splits.
+ * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD
+ */
+private[spark] object SqlNewHadoopRDDState {
+ /**
+ * The thread variable for the name of the current file being read. This is used by
+ * the InputFileName function in Spark SQL.
+ */
+ private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+ override protected def initialValue(): UTF8String = UTF8String.fromString("")
+ }
+
+ def getInputFileName(): UTF8String = inputFileName.get()
+
+ private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+ private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 33769363a0..b6979d0c82 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,7 +17,11 @@
package org.apache.spark.sql.catalyst.expressions;
-import java.io.*;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.io.OutputStream;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
@@ -26,12 +30,26 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.KryoSerializable;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
-
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.CalendarIntervalType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.NullType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
@@ -39,9 +57,23 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.sql.types.DataTypes.BooleanType;
+import static org.apache.spark.sql.types.DataTypes.ByteType;
+import static org.apache.spark.sql.types.DataTypes.DateType;
+import static org.apache.spark.sql.types.DataTypes.DoubleType;
+import static org.apache.spark.sql.types.DataTypes.FloatType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+import static org.apache.spark.sql.types.DataTypes.NullType;
+import static org.apache.spark.sql.types.DataTypes.ShortType;
+import static org.apache.spark.sql.types.DataTypes.TimestampType;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -116,11 +148,6 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
/** The size of this row's backing data, in bytes) */
private int sizeInBytes;
- private void setNotNullAt(int i) {
- assertIndexIsValid(i);
- BitSetMethods.unset(baseObject, baseOffset, i);
- }
-
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
@@ -187,6 +214,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
pointTo(buf, numFields, sizeInBytes);
}
+
+ public void setNotNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.unset(baseObject, baseOffset, i);
+ }
+
@Override
public void setNullAt(int i) {
assertIndexIsValid(i);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index d809877817..bf215783fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.rdd.SqlNewHadoopRDD
+import org.apache.spark.rdd.SqlNewHadoopRDDState
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{DataType, StringType}
@@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override protected def initInternal(): Unit = {}
override protected def evalInternal(input: InternalRow): UTF8String = {
- SqlNewHadoopRDD.getInputFileName()
+ SqlNewHadoopRDDState.getInputFileName()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
ev.isNull = "false"
s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();"
+ "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 8a92e489cc..dade488ca2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -109,6 +109,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private static final int DEFAULT_VAR_LEN_SIZE = 32;
/**
+ * Tries to initialize the reader for this split. Returns true if this reader supports reading
+ * this split and false otherwise.
+ */
+ public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) {
+ try {
+ initialize(inputSplit, taskAttemptContext);
+ return true;
+ } catch (Exception e) {
+ return false;
+ }
+ }
+
+ /**
* Implementation of RecordReader API.
*/
@Override
@@ -326,6 +339,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
} else {
rowWriters[n].write(col, bytes.array(), bytes.position(), len);
}
+ rows[n].setNotNullAt(col);
} else {
rows[n].setNullAt(col);
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index f40e603cd1..5ef3a48c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -323,6 +323,11 @@ private[spark] object SQLConf {
"option must be set in Hadoop Configuration. 2. This option overrides " +
"\"spark.sql.sources.outputCommitterClass\".")
+ val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf(
+ key = "spark.sql.parquet.enableUnsafeRowRecordReader",
+ defaultValue = Some(true),
+ doc = "Enables using the custom ParquetUnsafeRowRecordReader.")
+
val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
defaultValue = Some(false),
doc = "When true, enable filter pushdown for ORC files.")
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 4d176332b6..56cb63d9ef 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
+import scala.reflect.ClassTag
+
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.sql.{SQLConf, SQLContext}
+import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager}
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager}
import org.apache.spark.{Partition => SparkPartition, _}
-import scala.reflect.ClassTag
-
private[spark] class SqlNewHadoopPartition(
rddId: Int,
@@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition(
* changes based on [[org.apache.spark.rdd.HadoopRDD]].
*/
private[spark] class SqlNewHadoopRDD[V: ClassTag](
- sc : SparkContext,
+ sqlContext: SQLContext,
broadcastedConf: Broadcast[SerializableConfiguration],
@transient private val initDriverSideJobFuncOpt: Option[Job => Unit],
initLocalJobFuncOpt: Option[Job => Unit],
inputFormatClass: Class[_ <: InputFormat[Void, V]],
valueClass: Class[V])
- extends RDD[V](sc, Nil)
+ extends RDD[V](sqlContext.sparkContext, Nil)
with SparkHadoopMapReduceUtil
with Logging {
@@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// If true, enable using the custom RecordReader for parquet. This only works for
// a subset of the types (no complex types).
protected val enableUnsafeRowParquetReader: Boolean =
- sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true)
+ sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
override def getPartitions: Array[SparkPartition] = {
val conf = getConf(isDriverSide = true)
@@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
}
override def compute(
- theSplit: SparkPartition,
- context: TaskContext): Iterator[V] = {
+ theSplit: SparkPartition,
+ context: TaskContext): Iterator[V] = {
val iter = new Iterator[V] {
val split = theSplit.asInstanceOf[SqlNewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
@@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// Sets the thread local variable for the file's name
split.serializableHadoopSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDD.unsetInputFileName()
+ case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDDState.unsetInputFileName()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
@@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
* TODO: plumb this through a different way?
*/
if (enableUnsafeRowParquetReader &&
- format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
- // TODO: move this class to sql.execution and remove this.
- reader = Utils.classForName(
- "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader")
- .newInstance().asInstanceOf[RecordReader[Void, V]]
- try {
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- } catch {
- case e: Exception => reader = null
+ format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
+ val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader()
+ if (!parquetReader.tryInitialize(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)) {
+ parquetReader.close()
+ } else {
+ reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
}
}
@@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
private def close() {
if (reader != null) {
- SqlNewHadoopRDD.unsetInputFileName()
+ SqlNewHadoopRDDState.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
@@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
@@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
}
super.persist(storageLevel)
}
-}
-
-private[spark] object SqlNewHadoopRDD {
-
- /**
- * The thread variable for the name of the current file being read. This is used by
- * the InputFileName function in Spark SQL.
- */
- private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
- override protected def initialValue(): UTF8String = UTF8String.fromString("")
- }
-
- def getInputFileName(): UTF8String = inputFileName.get()
-
- private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
-
- private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
/**
* Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index cb0aab8cc0..fdd745f48e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -319,7 +319,7 @@ private[sql] class ParquetRelation(
Utils.withDummyCallSite(sqlContext.sparkContext) {
new SqlNewHadoopRDD(
- sc = sqlContext.sparkContext,
+ sqlContext = sqlContext,
broadcastedConf = broadcastedConf,
initDriverSideJobFuncOpt = Some(setInputPaths),
initLocalJobFuncOpt = Some(initLocalJobFuncOpt),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index c8028a5ef5..cc5aae03d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
- // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does
- // not do row by row filtering (and we probably don't want to push that).
- ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
+ // The unsafe row RecordReader does not support row by row filtering so run it with it disabled.
+ test("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
- withTempPath { dir =>
- val path = s"${dir.getCanonicalPath}/part=1"
- (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
- val df = sqlContext.read.parquet(path).filter("a = 2")
-
- // This is the source RDD without Spark-side filtering.
- val childRDD =
- df
- .queryExecution
- .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
- .child
- .execute()
-
- // The result should be single row.
- // When a filter is pushed to Parquet, Parquet can apply it to every row.
- // So, we can check the number of rows returned from the Parquet
- // to make sure our filter pushdown work.
- assert(childRDD.count == 1)
+ withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/part=1"
+ (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
+ val df = sqlContext.read.parquet(path).filter("a = 2")
+
+ // This is the source RDD without Spark-side filtering.
+ val childRDD =
+ df
+ .queryExecution
+ .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+ .child
+ .execute()
+
+ // The result should be single row.
+ // When a filter is pushed to Parquet, Parquet can apply it to every row.
+ // So, we can check the number of rows returned from the Parquet
+ // to make sure our filter pushdown work.
+ assert(childRDD.count == 1)
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 177ab42f77..0c5d4887ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
}
+ test("null and non-null strings") {
+ // Create a dataset where the first values are NULL and then some non-null values. The
+ // number of non-nulls needs to be bigger than the ParquetReader batch size.
+ val data = sqlContext.range(200).map { i =>
+ if (i.getLong(0) < 150) Row(None)
+ else Row("a")
+ }
+ val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil))
+ assert(df.agg("col" -> "count").collect().head.getLong(0) == 50)
+
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/data"
+ df.write.parquet(path)
+
+ val df2 = sqlContext.read.parquet(path)
+ assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ }
+ }
+
test("read dictionary encoded decimals written as INT32") {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary