author     Cheng Lian <lian@databricks.com>    2015-08-29 13:24:32 -0700
committer  Michael Armbrust <michael@databricks.com>    2015-08-29 13:24:32 -0700
commit     24ffa85c002a095ffb270175ec838995d3ed5469 (patch)
tree       21504dd13afa2b1d460129ffe5603d26e5818065 /sql
parent     5369be806848f43cb87c76504258c4e7de930c90 (diff)
[SPARK-10289] [SQL] A direct write API for testing Parquet
This PR introduces a direct write API for testing Parquet. It's a DSL-flavored version of the [`writeDirect` method][1] that comes with the parquet-avro testing code. With this API, it's much easier to construct arbitrary Parquet structures. It's especially useful when adding regression tests for various compatibility corner cases.

Sample usage of this API can be found in the new test case added in `ParquetThriftCompatibilitySuite`.

[1]: https://github.com/apache/parquet-mr/blob/apache-parquet-1.8.1/parquet-avro/src/test/java/org/apache/parquet/avro/TestArrayCompatibility.java#L945-L972

Author: Cheng Lian <lian@databricks.com>

Closes #8454 from liancheng/spark-10289/parquet-testing-direct-write-api.
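For reference, here is a minimal usage sketch of the new DSL (not part of this patch). The output path and one-column schema are made up; `writeDirect`, `message`, and `field` are the helpers defined on the `ParquetCompatibilityTest` companion object in the diff below:

```scala
// Assumes the caller lives in a package that can see the private[sql] object.
import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest._

// Hypothetical one-column schema; writeDirect parses it with MessageTypeParser.
val schema =
  """message example {
    |  required int32 id;
    |}
  """.stripMargin

// Each record writer is a plain RecordConsumer => Unit closure; the implicit
// RecordConsumerDSL wraps the matching start/end calls around each body.
writeDirect("/tmp/example.parquet", schema,
  { rc =>
    rc.message {            // startMessage() ... endMessage()
      rc.field("id", 0) {   // startField("id", 0) ... endField("id", 0)
        rc.addInteger(0)
      }
    }
  },
  { rc =>
    rc.message {
      rc.field("id", 0) {
        rc.addInteger(1)
      }
    }
  })
```

Because the record writers are ordinary closures, arbitrarily nested groups and repeated fields can be spelled out explicitly, which is what the new `ParquetThriftCompatibilitySuite` test relies on.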
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala       |  84
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala | 100
2 files changed, 160 insertions, 24 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index df68432fae..91f3ce4d34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -17,11 +17,15 @@
package org.apache.spark.sql.execution.datasources.parquet
-import scala.collection.JavaConverters._
+import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
-import org.apache.parquet.hadoop.ParquetFileReader
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.hadoop.api.WriteSupport
+import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
+import org.apache.parquet.io.api.RecordConsumer
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.QueryTest
@@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
override def accept(path: Path): Boolean = pathFilter(path)
- }).toSeq
+ }).toSeq.asJava
- val footers =
- ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true)
- footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema
+ val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
+ footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
}
protected def logParquetSchema(path: String): Unit = {
@@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
}
}
-object ParquetCompatibilityTest {
- def makeNullable[T <: AnyRef](i: Int)(f: => T): T = {
- if (i % 3 == 0) null.asInstanceOf[T] else f
+private[sql] object ParquetCompatibilityTest {
+ implicit class RecordConsumerDSL(consumer: RecordConsumer) {
+ def message(f: => Unit): Unit = {
+ consumer.startMessage()
+ f
+ consumer.endMessage()
+ }
+
+ def group(f: => Unit): Unit = {
+ consumer.startGroup()
+ f
+ consumer.endGroup()
+ }
+
+ def field(name: String, index: Int)(f: => Unit): Unit = {
+ consumer.startField(name, index)
+ f
+ consumer.endField(name, index)
+ }
+ }
+
+ /**
+ * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet
+ * records with arbitrary structures.
+ */
+ private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
+ extends WriteSupport[RecordConsumer => Unit] {
+
+ private var recordConsumer: RecordConsumer = _
+
+ override def init(configuration: Configuration): WriteContext = {
+ new WriteContext(schema, metadata.asJava)
+ }
+
+ override def write(recordWriter: RecordConsumer => Unit): Unit = {
+ recordWriter.apply(recordConsumer)
+ }
+
+ override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+ this.recordConsumer = recordConsumer
+ }
+ }
+
+ /**
+ * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`.
+ * Records are produced by `recordWriters`.
+ */
+ def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = {
+ writeDirect(path, schema, Map.empty[String, String], recordWriters: _*)
+ }
+
+ /**
+ * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`
+ * with given user-defined key-value `metadata`. Records are produced by `recordWriters`.
+ */
+ def writeDirect(
+ path: String,
+ schema: String,
+ metadata: Map[String, String],
+ recordWriters: (RecordConsumer => Unit)*): Unit = {
+ val messageType = MessageTypeParser.parseMessageType(schema)
+ val writeSupport = new DirectWriteSupport(messageType, metadata)
+ val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
+ try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index b789c5a106..88a3d878f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
""".stripMargin)
checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
- def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)
-
val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
- Row(
+ val nonNullablePrimitiveValues = Seq(
i % 2 == 0,
i.toByte,
(i + 1).toShort,
@@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
s"val_$i",
s"val_$i",
// Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
- suits(i % 4),
-
- nullable(i % 2 == 0: java.lang.Boolean),
- nullable(i.toByte: java.lang.Byte),
- nullable((i + 1).toShort: java.lang.Short),
- nullable(i + 2: Integer),
- nullable((i * 10).toLong: java.lang.Long),
- nullable(i.toDouble + 0.2d: java.lang.Double),
- nullable(s"val_$i"),
- nullable(s"val_$i"),
- nullable(suits(i % 4)),
+ suits(i % 4))
+
+ val nullablePrimitiveValues = if (i % 3 == 0) {
+ Seq.fill(nonNullablePrimitiveValues.length)(null)
+ } else {
+ nonNullablePrimitiveValues
+ }
+ val complexValues = Seq(
Seq.tabulate(3)(n => s"arr_${i + n}"),
// Thrift `SET`s are converted to Parquet `LIST`s
Seq(i),
@@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
}
}.toMap)
+
+ Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
})
}
+
+ test("SPARK-10136 list of primitive list") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ // This Parquet schema is translated from the following Thrift schema:
+ //
+ // struct ListOfPrimitiveList {
+ // 1: list<list<i32>> f;
+ // }
+ val schema =
+ s"""message ListOfPrimitiveList {
+ | required group f (LIST) {
+ | repeated group f_tuple (LIST) {
+ | repeated int32 f_tuple_tuple;
+ | }
+ | }
+ |}
+ """.stripMargin
+
+ writeDirect(path, schema, { rc =>
+ rc.message {
+ rc.field("f", 0) {
+ rc.group {
+ rc.field("f_tuple", 0) {
+ rc.group {
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(0)
+ rc.addInteger(1)
+ }
+ }
+
+ rc.group {
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(2)
+ rc.addInteger(3)
+ }
+ }
+ }
+ }
+ }
+ }
+ }, { rc =>
+ rc.message {
+ rc.field("f", 0) {
+ rc.group {
+ rc.field("f_tuple", 0) {
+ rc.group {
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(4)
+ rc.addInteger(5)
+ }
+ }
+
+ rc.group {
+ rc.field("f_tuple_tuple", 0) {
+ rc.addInteger(6)
+ rc.addInteger(7)
+ }
+ }
+ }
+ }
+ }
+ }
+ })
+
+ logParquetSchema(path)
+
+ checkAnswer(
+ sqlContext.read.parquet(path),
+ Seq(
+ Row(Seq(Seq(0, 1), Seq(2, 3))),
+ Row(Seq(Seq(4, 5), Seq(6, 7)))))
+ }
+ }
}