Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala |   4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala         | 229
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                           |   9
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala                 |   6
 sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala                        |  28
 sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala            |  16
 6 files changed, 202 insertions(+), 90 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 1eb5571579..1a4ac06c7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
-case object DynamicType extends DataType {
- def simpleString: String = "dynamic"
-}
+case object DynamicType extends DataType
/**
* Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index ac043d4dd8..1d375b8754 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types
import java.sql.Timestamp
-import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
+import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
+import org.json4s.JsonAST.JValue
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
-/**
- * Utility functions for working with DataTypes.
- */
-object DataType extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- "StringType" ^^^ StringType |
- "FloatType" ^^^ FloatType |
- "IntegerType" ^^^ IntegerType |
- "ByteType" ^^^ ByteType |
- "ShortType" ^^^ ShortType |
- "DoubleType" ^^^ DoubleType |
- "LongType" ^^^ LongType |
- "BinaryType" ^^^ BinaryType |
- "BooleanType" ^^^ BooleanType |
- "DecimalType" ^^^ DecimalType |
- "TimestampType" ^^^ TimestampType
-
- protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
- }
- protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
- case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+object DataType {
+ def fromJson(json: String): DataType = parseDataType(parse(json))
+
+ private object JSortedObject {
+ def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
}
+ }
+
+ // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
+ private def parseDataType(json: JValue): DataType = json match {
+ case JString(name) =>
+ PrimitiveType.nameToType(name)
+
+ case JSortedObject(
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
+ ArrayType(parseDataType(t), n)
+
+ case JSortedObject(
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
+ MapType(parseDataType(k), parseDataType(v), n)
+
+ case JSortedObject(
+ ("fields", JArray(fields)),
+ ("type", JString("struct"))) =>
+ StructType(fields.map(parseStructField))
+ }
- protected lazy val structField: Parser[StructField] =
- ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
- case name ~ tpe ~ nullable =>
+ private def parseStructField(json: JValue): StructField = json match {
+ case JSortedObject(
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable)
+ }
+
+ @deprecated("Use DataType.fromJson instead")
+ def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
+
+ private object CaseClassStringParser extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ ( "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DecimalType" ^^^ DecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
+
+ protected lazy val arrayType: Parser[DataType] =
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+ case name ~ tpe ~ nullable =>
StructField(name, tpe, nullable = nullable)
- }
+ }
- protected lazy val boolVal: Parser[Boolean] =
- "true" ^^^ true |
- "false" ^^^ false
+ protected lazy val boolVal: Parser[Boolean] =
+ ( "true" ^^^ true
+ | "false" ^^^ false
+ )
- protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => new StructType(fields)
- }
+ protected lazy val structType: Parser[DataType] =
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+ case fields => new StructType(fields)
+ }
- protected lazy val dataType: Parser[DataType] =
- arrayType |
- mapType |
- structType |
- primitiveType
+ protected lazy val dataType: Parser[DataType] =
+ ( arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
+
+ /**
+ * Parses a string representation of a DataType.
+ *
+ * TODO: Generate parser as pickler...
+ */
+ def apply(asString: String): DataType = parseAll(dataType, asString) match {
+ case Success(result, _) => result
+ case failure: NoSuccess =>
+ throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
+ }
- /**
- * Parses a string representation of a DataType.
- *
- * TODO: Generate parser as pickler...
- */
- def apply(asString: String): DataType = parseAll(dataType, asString) match {
- case Success(result, _) => result
- case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
}
protected[types] def buildFormattedString(
@@ -111,15 +165,19 @@ abstract class DataType {
def isPrimitive: Boolean = false
- def simpleString: String
-}
+ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
+
+ private[sql] def jsonValue: JValue = typeName
-case object NullType extends DataType {
- def simpleString: String = "null"
+ def json: String = compact(render(jsonValue))
+
+ def prettyJson: String = pretty(render(jsonValue))
}
+case object NullType extends DataType
+
object NativeType {
- def all = Seq(
+ val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
@@ -139,6 +197,12 @@ trait PrimitiveType extends DataType {
override def isPrimitive = true
}
+object PrimitiveType {
+ private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
+
+ private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+}
+
abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
@@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "string"
}
case object BinaryType extends NativeType with PrimitiveType {
@@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
- return x.length - y.length
+ x.length - y.length
}
}
- def simpleString: String = "binary"
}
case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "boolean"
}
case object TimestampType extends NativeType {
@@ -187,8 +248,6 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
-
- def simpleString: String = "timestamp"
}
abstract class NumericType extends NativeType with PrimitiveType {
@@ -222,7 +281,6 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "long"
}
case object IntegerType extends IntegralType {
@@ -231,7 +289,6 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "integer"
}
case object ShortType extends IntegralType {
@@ -240,7 +297,6 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "short"
}
case object ByteType extends IntegralType {
@@ -249,7 +305,6 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "byte"
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
@@ -271,7 +326,6 @@ case object DecimalType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[BigDecimal]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = BigDecimalAsIfIntegral
- def simpleString: String = "decimal"
}
case object DoubleType extends FractionalType {
@@ -281,7 +335,6 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
- def simpleString: String = "double"
}
case object FloatType extends FractionalType {
@@ -291,12 +344,12 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
- def simpleString: String = "float"
}
object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
+ def typeName: String = "array"
}
/**
@@ -309,11 +362,14 @@ object ArrayType {
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(
- s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
+ s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
DataType.buildFormattedString(elementType, s"$prefix |", builder)
}
- def simpleString: String = "array"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("elementType" -> elementType.jsonValue) ~
+ ("containsNull" -> containsNull)
}
/**
@@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
case class StructField(name: String, dataType: DataType, nullable: Boolean) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
+ builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+
+ private[sql] def jsonValue: JValue = {
+ ("name" -> name) ~
+ ("type" -> dataType.jsonValue) ~
+ ("nullable" -> nullable)
+ }
}
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+
+ def typeName = "struct"
}
case class StructType(fields: Seq[StructField]) extends DataType {
@@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
- nameToField.get(name).getOrElse(
- throw new IllegalArgumentException(s"Field ${name} does not exist."))
+ nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
}
/**
@@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
- if (!nonExistFields.isEmpty) {
+ if (nonExistFields.nonEmpty) {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
@@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType {
fields.foreach(field => field.buildFormattedString(prefix, builder))
}
- def simpleString: String = "struct"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("fields" -> fields.map(_.jsonValue))
}
object MapType {
@@ -394,6 +459,8 @@ object MapType {
*/
def apply(keyType: DataType, valueType: DataType): MapType =
MapType(keyType: DataType, valueType: DataType, true)
+
+ def simpleName = "map"
}
/**
@@ -407,12 +474,16 @@ case class MapType(
valueType: DataType,
valueContainsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
- builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
- s"(valueContainsNull = ${valueContainsNull})\n")
+ builder.append(s"$prefix-- key: ${keyType.typeName}\n")
+ builder.append(s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
DataType.buildFormattedString(keyType, s"$prefix |", builder)
DataType.buildFormattedString(valueType, s"$prefix |", builder)
}
- def simpleString: String = "map"
+ override private[sql] def jsonValue: JValue =
+ ("type" -> typeName) ~
+ ("keyType" -> keyType.jsonValue) ~
+ ("valueType" -> valueType.jsonValue) ~
+ ("valueContainsNull" -> valueContainsNull)
}
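
For reference, a minimal usage sketch of the JSON round trip introduced above. This is not part of the patch; the object name and example schema are illustrative only.

    // Sketch only: exercises the new DataType.json / DataType.fromJson methods.
    import org.apache.spark.sql.catalyst.types._

    object JsonRoundTripSketch extends App {
      // Illustrative schema: a nullable int column and a non-nullable array of doubles.
      val schema = StructType(Seq(
        StructField("a", IntegerType, nullable = true),
        StructField("b", ArrayType(DoubleType), nullable = false)))

      val serialized = schema.json            // compact JSON, e.g. {"type":"struct","fields":[...]}
      val roundTripped = DataType.fromJson(serialized)

      assert(roundTripped == schema)          // structural equality survives the round trip
      println(schema.prettyJson)              // pretty-printed form for inspection
    }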
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7a55c5bf97..35561cac3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
@@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.SparkStrategies
+import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.{Logging, SparkContext}
/**
* :: AlphaComponent ::
@@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
- val parser = org.apache.spark.sql.catalyst.types.DataType
- parser(dataTypeString)
+ DataType.fromJson(dataTypeString)
}
/**
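
A small sketch (not part of the patch) of the JSON schema-string format that parseDataType now accepts from the Python side; the column names are made up, and key order within each JSON object does not matter because the parser sorts keys before matching.

    import org.apache.spark.sql.catalyst.types.DataType

    object SchemaStringSketch extends App {
      // Hypothetical schema string in the new JSON format (previously this would
      // have been a case class string such as "StructType(List(StructField(...)))").
      val json =
        """{"type":"struct","fields":[
          |  {"name":"id","type":"integer","nullable":false},
          |  {"name":"name","type":"string","nullable":true}]}""".stripMargin

      println(DataType.fromJson(json))
      // StructType(List(StructField(id,IntegerType,false), StructField(name,StringType,true)))
    }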
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 2941b97935..e6389cf77a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet
import java.io.IOException
+import scala.util.Try
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
@@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
def convertFromString(string: String): Seq[Attribute] = {
- DataType(string) match {
+ Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
case other => sys.error(s"Can convert $string to row")
}
}
def convertToString(schema: Seq[Attribute]): String = {
- StructType.fromAttributes(schema).toString
+ StructType.fromAttributes(schema).json
}
def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = {
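
A minimal sketch of the fallback logic used in convertFromString above, showing that both the new JSON form and the legacy case class string still parse; the helper name parseSchema and the example schema are hypothetical.

    import scala.util.Try
    import org.apache.spark.sql.catalyst.types._

    object SchemaCompatSketch extends App {
      // Mirrors the fallback in convertFromString: try JSON first, then the legacy format.
      def parseSchema(schemaString: String): DataType =
        Try(DataType.fromJson(schemaString)).getOrElse(DataType.fromCaseClassString(schemaString))

      val schema = StructType(Seq(
        StructField("c1", IntegerType, nullable = false),
        StructField("c2", StringType, nullable = true)))

      assert(parseSchema(schema.json) == schema)      // new JSON representation
      assert(parseSchema(schema.toString) == schema)  // legacy case class string
    }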
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 8fb59c5830..100ecb45e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.types.DataType
+
class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
@@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite {
struct(Set("b", "d", "e", "f"))
}
}
+
+ def checkDataTypeJsonRepr(dataType: DataType): Unit = {
+ test(s"JSON - $dataType") {
+ assert(DataType.fromJson(dataType.json) === dataType)
+ }
+ }
+
+ checkDataTypeJsonRepr(BooleanType)
+ checkDataTypeJsonRepr(ByteType)
+ checkDataTypeJsonRepr(ShortType)
+ checkDataTypeJsonRepr(IntegerType)
+ checkDataTypeJsonRepr(LongType)
+ checkDataTypeJsonRepr(FloatType)
+ checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(TimestampType)
+ checkDataTypeJsonRepr(StringType)
+ checkDataTypeJsonRepr(BinaryType)
+ checkDataTypeJsonRepr(ArrayType(DoubleType, true))
+ checkDataTypeJsonRepr(ArrayType(StringType, false))
+ checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
+ checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeJsonRepr(
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 07adf73140..25e41ecf28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result3(0)(1) === "the answer")
Utils.deleteRecursively(tmpdir)
}
-
+
test("Querying on empty parquet throws exception (SPARK-3536)") {
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
@@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result1.size === 0)
Utils.deleteRecursively(tmpdir)
}
+
+ test("DataType string parser compatibility") {
+ val schema = StructType(List(
+ StructField("c1", IntegerType, false),
+ StructField("c2", BinaryType, false)))
+
+ val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString)
+ val fromJson = ParquetTypesConverter.convertFromString(schema.json)
+
+ (fromCaseClassString, fromJson).zipped.foreach { (a, b) =>
+ assert(a.name == b.name)
+ assert(a.dataType === b.dataType)
+ }
+ }
}