-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                                     |  6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala                          | 41
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java          | 59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala    |  6
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                                   |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala) | 60
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala | 43
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 19
10 files changed, 175 insertions(+), 80 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 7db5834687..f37c95bedc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -215,8 +215,8 @@ class HadoopRDD[K, V](
// Sets the thread local variable for the file's name
split.inputSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDD.unsetInputFileName()
+ case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDDState.unsetInputFileName()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
@@ -256,7 +256,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
- SqlNewHadoopRDD.unsetInputFileName()
+ SqlNewHadoopRDDState.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
new file mode 100644
index 0000000000..3f15fff793
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * State for SqlNewHadoopRDD objects. This is kept in a separate object in core because of the package split: SqlNewHadoopRDD itself now lives in sql/core.
+ * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD
+ */
+private[spark] object SqlNewHadoopRDDState {
+ /**
+ * The thread variable for the name of the current file being read. This is used by
+ * the InputFileName function in Spark SQL.
+ */
+ private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+ override protected def initialValue(): UTF8String = UTF8String.fromString("")
+ }
+
+ def getInputFileName(): UTF8String = inputFileName.get()
+
+ private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+ private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
+}
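The object above only holds a per-thread file name; the surrounding pattern is always set-before-read and unset-on-close, as the HadoopRDD and SqlNewHadoopRDD hunks in this patch show. A minimal sketch of that pattern (illustrative only; withInputFile is not part of the patch, and because the setters are private[spark] it would have to live under org.apache.spark.rdd):

package org.apache.spark.rdd

object InputFileNameSketch {
  // Runs `body` with the thread-local input file name set to `path`, clearing it afterwards
  // so a task that reads no file does not report a stale name.
  def withInputFile[T](path: String)(body: => T): T = {
    SqlNewHadoopRDDState.setInputFileName(path)
    try body
    finally SqlNewHadoopRDDState.unsetInputFileName()
  }
}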
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 33769363a0..b6979d0c82 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,7 +17,11 @@
package org.apache.spark.sql.catalyst.expressions;
-import java.io.*;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.io.OutputStream;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
@@ -26,12 +30,26 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.KryoSerializable;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
-
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.CalendarIntervalType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.NullType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
@@ -39,9 +57,23 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.sql.types.DataTypes.BooleanType;
+import static org.apache.spark.sql.types.DataTypes.ByteType;
+import static org.apache.spark.sql.types.DataTypes.DateType;
+import static org.apache.spark.sql.types.DataTypes.DoubleType;
+import static org.apache.spark.sql.types.DataTypes.FloatType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+import static org.apache.spark.sql.types.DataTypes.NullType;
+import static org.apache.spark.sql.types.DataTypes.ShortType;
+import static org.apache.spark.sql.types.DataTypes.TimestampType;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -116,11 +148,6 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
/** The size of this row's backing data, in bytes. */
private int sizeInBytes;
- private void setNotNullAt(int i) {
- assertIndexIsValid(i);
- BitSetMethods.unset(baseObject, baseOffset, i);
- }
-
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
@@ -187,6 +214,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
pointTo(buf, numFields, sizeInBytes);
}
+
+ public void setNotNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.unset(baseObject, baseOffset, i);
+ }
+
@Override
public void setNullAt(int i) {
assertIndexIsValid(i);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index d809877817..bf215783fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.rdd.SqlNewHadoopRDD
+import org.apache.spark.rdd.SqlNewHadoopRDDState
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{DataType, StringType}
@@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override protected def initInternal(): Unit = {}
override protected def evalInternal(input: InternalRow): UTF8String = {
- SqlNewHadoopRDD.getInputFileName()
+ SqlNewHadoopRDDState.getInputFileName()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
ev.isNull = "false"
s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();"
+ "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
}
}
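On the user-facing side this expression backs the input_file_name() SQL function. A usage sketch, assuming the function is registered under that name in FunctionRegistry (the path below is made up):

val df = sqlContext.read.parquet("/tmp/example")        // hypothetical path
df.selectExpr("input_file_name()").show()
// Each row reports the file the thread-local state pointed at while that row was read,
// or the empty string when no file-backed scan produced it.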
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 8a92e489cc..dade488ca2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -109,6 +109,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
private static final int DEFAULT_VAR_LEN_SIZE = 32;
/**
+ * Tries to initialize the reader for this split. Returns true if this reader supports reading
+ * this split and false otherwise.
+ */
+ public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) {
+ try {
+ initialize(inputSplit, taskAttemptContext);
+ return true;
+ } catch (Exception e) {
+ return false;
+ }
+ }
+
+ /**
* Implementation of RecordReader API.
*/
@Override
@@ -326,6 +339,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
} else {
rowWriters[n].write(col, bytes.array(), bytes.position(), len);
}
+ rows[n].setNotNullAt(col);
} else {
rows[n].setNullAt(col);
}
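tryInitialize lets callers probe whether the fast reader can handle a split without reflection and without leaking a half-initialized reader. A caller-side sketch of that fallback, mirroring the SqlNewHadoopRDD hunk further down (chooseReader and the fallback parameter are illustrative; the latter stands in for format.createRecordReader):

import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader

def chooseReader[V](
    split: InputSplit,
    context: TaskAttemptContext,
    fallback: () => RecordReader[Void, V]): RecordReader[Void, V] = {
  val fast = new UnsafeRowParquetRecordReader()
  if (fast.tryInitialize(split, context)) {
    fast.asInstanceOf[RecordReader[Void, V]]   // fast path: flat-schema Parquet
  } else {
    fast.close()                               // release anything the failed init opened
    fallback()                                 // stock reader for unsupported schemas
  }
}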
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index f40e603cd1..5ef3a48c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -323,6 +323,11 @@ private[spark] object SQLConf {
"option must be set in Hadoop Configuration. 2. This option overrides " +
"\"spark.sql.sources.outputCommitterClass\".")
+ val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf(
+ key = "spark.sql.parquet.enableUnsafeRowRecordReader",
+ defaultValue = Some(true),
+ doc = "Enables using the custom ParquetUnsafeRowRecordReader.")
+
val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
defaultValue = Some(false),
doc = "When true, enable filter pushdown for ORC files.")
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 4d176332b6..56cb63d9ef 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
+import scala.reflect.ClassTag
+
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.sql.{SQLConf, SQLContext}
+import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager}
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager}
import org.apache.spark.{Partition => SparkPartition, _}
-import scala.reflect.ClassTag
-
private[spark] class SqlNewHadoopPartition(
rddId: Int,
@@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition(
* changes based on [[org.apache.spark.rdd.HadoopRDD]].
*/
private[spark] class SqlNewHadoopRDD[V: ClassTag](
- sc : SparkContext,
+ sqlContext: SQLContext,
broadcastedConf: Broadcast[SerializableConfiguration],
@transient private val initDriverSideJobFuncOpt: Option[Job => Unit],
initLocalJobFuncOpt: Option[Job => Unit],
inputFormatClass: Class[_ <: InputFormat[Void, V]],
valueClass: Class[V])
- extends RDD[V](sc, Nil)
+ extends RDD[V](sqlContext.sparkContext, Nil)
with SparkHadoopMapReduceUtil
with Logging {
@@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// If true, enable using the custom RecordReader for parquet. This only works for
// a subset of the types (no complex types).
protected val enableUnsafeRowParquetReader: Boolean =
- sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true)
+ sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
override def getPartitions: Array[SparkPartition] = {
val conf = getConf(isDriverSide = true)
@@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
}
override def compute(
- theSplit: SparkPartition,
- context: TaskContext): Iterator[V] = {
+ theSplit: SparkPartition,
+ context: TaskContext): Iterator[V] = {
val iter = new Iterator[V] {
val split = theSplit.asInstanceOf[SqlNewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
@@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// Sets the thread local variable for the file's name
split.serializableHadoopSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDD.unsetInputFileName()
+ case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDDState.unsetInputFileName()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
@@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
* TODO: plumb this through a different way?
*/
if (enableUnsafeRowParquetReader &&
- format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
- // TODO: move this class to sql.execution and remove this.
- reader = Utils.classForName(
- "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader")
- .newInstance().asInstanceOf[RecordReader[Void, V]]
- try {
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- } catch {
- case e: Exception => reader = null
+ format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
+ val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader()
+ if (!parquetReader.tryInitialize(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)) {
+ parquetReader.close()
+ } else {
+ reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
}
}
@@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
private def close() {
if (reader != null) {
- SqlNewHadoopRDD.unsetInputFileName()
+ SqlNewHadoopRDDState.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
@@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
@@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
}
super.persist(storageLevel)
}
-}
-
-private[spark] object SqlNewHadoopRDD {
-
- /**
- * The thread variable for the name of the current file being read. This is used by
- * the InputFileName function in Spark SQL.
- */
- private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
- override protected def initialValue(): UTF8String = UTF8String.fromString("")
- }
-
- def getInputFileName(): UTF8String = inputFileName.get()
-
- private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
-
- private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
/**
* Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index cb0aab8cc0..fdd745f48e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -319,7 +319,7 @@ private[sql] class ParquetRelation(
Utils.withDummyCallSite(sqlContext.sparkContext) {
new SqlNewHadoopRDD(
- sc = sqlContext.sparkContext,
+ sqlContext = sqlContext,
broadcastedConf = broadcastedConf,
initDriverSideJobFuncOpt = Some(setInputPaths),
initLocalJobFuncOpt = Some(initLocalJobFuncOpt),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index c8028a5ef5..cc5aae03d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
- // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does
- // not do row by row filtering (and we probably don't want to push that).
- ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
+ // The unsafe-row RecordReader does not support row-by-row filtering, so run this test with it disabled.
+ test("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
- withTempPath { dir =>
- val path = s"${dir.getCanonicalPath}/part=1"
- (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
- val df = sqlContext.read.parquet(path).filter("a = 2")
-
- // This is the source RDD without Spark-side filtering.
- val childRDD =
- df
- .queryExecution
- .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
- .child
- .execute()
-
- // The result should be single row.
- // When a filter is pushed to Parquet, Parquet can apply it to every row.
- // So, we can check the number of rows returned from the Parquet
- // to make sure our filter pushdown work.
- assert(childRDD.count == 1)
+ withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/part=1"
+ (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
+ val df = sqlContext.read.parquet(path).filter("a = 2")
+
+ // This is the source RDD without Spark-side filtering.
+ val childRDD =
+ df
+ .queryExecution
+ .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+ .child
+ .execute()
+
+ // The result should be single row.
+ // When a filter is pushed to Parquet, Parquet can apply it to every row.
+ // So, we can check the number of rows returned from the Parquet
+ // to make sure our filter pushdown work.
+ assert(childRDD.count == 1)
+ }
}
}
}
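withSQLConf (from SQLTestUtils) restores each key's previous value when the block exits. If, as SQLTestUtils defines it, the helper takes the key/value pairs as varargs, the nesting above could equally be written as one call:

withSQLConf(
    SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
    SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
  // ... same body as the nested version above ...
}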
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 177ab42f77..0c5d4887ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
}
+ test("null and non-null strings") {
+ // Create a dataset where the first values are NULL and then some non-null values. The
+ // number of non-nulls needs to be bigger than the ParquetReader batch size.
+ val data = sqlContext.range(200).map { i =>
+ if (i.getLong(0) < 150) Row(None)
+ else Row("a")
+ }
+ val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil))
+ assert(df.agg("col" -> "count").collect().head.getLong(0) == 50)
+
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/data"
+ df.write.parquet(path)
+
+ val df2 = sqlContext.read.parquet(path)
+ assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50)
+ }
+ }
+
test("read dictionary encoded decimals written as INT32") {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary