diff options
2 files changed, 160 insertions, 24 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index df68432fae..91f3ce4d34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -17,11 +17,15 @@
package org.apache.spark.sql.execution.datasources.parquet
-import scala.collection.JavaConverters._
+import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
-import org.apache.parquet.hadoop.ParquetFileReader
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.hadoop.api.WriteSupport
+import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
+import org.apache.parquet.io.api.RecordConsumer
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.QueryTest
@@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
override def accept(path: Path): Boolean = pathFilter(path)
- }).toSeq
+ }).toSeq.asJava
- val footers =
- ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true)
- footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema
+ val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
+ footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
protected def logParquetSchema(path: String): Unit = {
@@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
-object ParquetCompatibilityTest {
- def makeNullable[T <: AnyRef](i: Int)(f: => T): T = {
- if (i % 3 == 0) null.asInstanceOf[T] else f
+private[sql] object ParquetCompatibilityTest {
+ /**
+  * Syntax sugar over a raw [[RecordConsumer]]: each method emits the matching start call, runs
+  * the by-name body, and then emits the matching end call.
+  */
+ implicit class RecordConsumerDSL(consumer: RecordConsumer) {
+   /** Runs `f` between `startMessage()` and `endMessage()`. */
+   def message(f: => Unit): Unit =
+     enclose(consumer.startMessage(), consumer.endMessage())(f)
+
+   /** Runs `f` between `startGroup()` and `endGroup()`. */
+   def group(f: => Unit): Unit =
+     enclose(consumer.startGroup(), consumer.endGroup())(f)
+
+   /** Runs `f` between `startField(name, index)` and `endField(name, index)`. */
+   def field(name: String, index: Int)(f: => Unit): Unit =
+     enclose(consumer.startField(name, index), consumer.endField(name, index))(f)
+
+   // Evaluates `open`, `body` and `close` exactly once each, in that order. There is no
+   // try/finally: when `body` throws, `close` is skipped.
+   private def enclose(open: => Unit, close: => Unit)(body: => Unit): Unit = {
+     open
+     body
+     close
+   }
+ }
+ /**
+  * Test-only [[WriteSupport]] whose "records" are functions of type `RecordConsumer => Unit`.
+  * Writing a record applies that function to the current [[RecordConsumer]], which lets callers
+  * emit hand-built Parquet records of arbitrary structure. The `schema` and key-value `metadata`
+  * given to the constructor are returned from `init` inside the [[WriteContext]].
+  */
+ private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
+   extends WriteSupport[RecordConsumer => Unit] {
+
+   // Assigned in `prepareForWrite`; holds the default (null) value until then.
+   private var consumer: RecordConsumer = _
+
+   override def init(configuration: Configuration): WriteContext =
+     new WriteContext(schema, metadata.asJava)
+
+   override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+     consumer = recordConsumer
+   }
+
+   override def write(recordWriter: RecordConsumer => Unit): Unit = recordWriter(consumer)
+ }
+ /**
+  * Convenience overload of `writeDirect` that attaches no user-defined key-value metadata: the
+  * records produced by `recordWriters`, conforming to `schema`, are written to the Parquet file
+  * located by `path`.
+  */
+ def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = {
+   val noMetadata = Map.empty[String, String]
+   writeDirect(path, schema, noMetadata, recordWriters: _*)
+ }
+ /**
+ * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`
+ * with given user-defined key-value `metadata`. Records are produced by `recordWriters`.
+ *
+ * @param path location of the Parquet file to write
+ * @param schema Parquet message type in textual form; parsed with `MessageTypeParser`
+ * @param metadata user-defined key-value metadata handed to the `DirectWriteSupport`
+ * @param recordWriters one function per record, passed to the Parquet writer in the given order
+ */
+ def writeDirect(
+ path: String,
+ schema: String,
+ metadata: Map[String, String],
+ recordWriters: (RecordConsumer => Unit)*): Unit = {
+ val messageType = MessageTypeParser.parseMessageType(schema)
+ val writeSupport = new DirectWriteSupport(messageType, metadata)
+ val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
+ try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close() // close the writer even when a record writer throws
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index b789c5a106..88a3d878f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
- def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)
val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
- Row(
+ val nonNullablePrimitiveValues = Seq(
i % 2 == 0,
(i + 1).toShort,
@@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
// Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
- suits(i % 4),
- nullable(i % 2 == 0: java.lang.Boolean),
- nullable(i.toByte: java.lang.Byte),
- nullable((i + 1).toShort: java.lang.Short),
- nullable(i + 2: Integer),
- nullable((i * 10).toLong: java.lang.Long),
- nullable(i.toDouble + 0.2d: java.lang.Double),
- nullable(s"val_$i"),
- nullable(s"val_$i"),
- nullable(suits(i % 4)),
+ suits(i % 4))
+ val nullablePrimitiveValues = if (i % 3 == 0) {
+ Seq.fill(nonNullablePrimitiveValues.length)(null)
+ } else {
+ nonNullablePrimitiveValues
+ }
+ val complexValues = Seq(
Seq.tabulate(3)(n => s"arr_${i + n}"),
// Thrift `SET`s are converted to Parquet `LIST`s
@@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
+ Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
+ test("SPARK-10136 list of primitive list") { // hand-writes two list<list<i32>> records, then checks what Spark reads back
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ // This Parquet schema is translated from the following Thrift schema:
+ //
+ // struct ListOfPrimitiveList {
+ // 1: list<list<i32>> f;
+ // }
+ val schema =
+ s"""message ListOfPrimitiveList {
+ | required group f (LIST) {
+ | repeated group f_tuple (LIST) {
+ | repeated int32 f_tuple_tuple;
+ | }
+ | }
+ |}
+ """.stripMargin
+ writeDirect(path, schema, { rc => // record 1: f = [[0, 1], [2, 3]]
+ rc.message {
+ rc.field("f", 0) {
+ rc.group {
+ rc.field("f_tuple", 0) {
+ rc.group { // inner list [0, 1]
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(0)
+ rc.addInteger(1)
+ }
+ }
+ rc.group { // inner list [2, 3]
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(2)
+ rc.addInteger(3)
+ }
+ }
+ }
+ }
+ }
+ }
+ }, { rc => // record 2: f = [[4, 5], [6, 7]]
+ rc.message {
+ rc.field("f", 0) {
+ rc.group {
+ rc.field("f_tuple", 0) {
+ rc.group { // inner list [4, 5]
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(4)
+ rc.addInteger(5)
+ }
+ }
+ rc.group { // inner list [6, 7]
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(6)
+ rc.addInteger(7)
+ }
+ }
+ }
+ }
+ }
+ }
+ })
+ logParquetSchema(path)
+ checkAnswer( // expects one Row per written record, each with a single Seq[Seq[Int]] column
+ sqlContext.read.parquet(path),
+ Seq(
+ Row(Seq(Seq(0, 1), Seq(2, 3))),
+ Row(Seq(Seq(4, 5), Seq(6, 7)))))
+ }
+ }