author    Cheng Lian <lian.cs.zju@gmail.com>    2014-10-08 17:04:49 -0700
committer Michael Armbrust <michael@databricks.com>    2014-10-08 17:04:49 -0700
commit    a42cc08d219c579019f613faa8d310e6069c06fe (patch)
tree      47adb5abf147cd477a88e33524de43b29379b990
parent    a85f24accd3266e0f97ee04d03c22b593d99c062 (diff)
[SPARK-3713][SQL] Uses JSON to serialize DataType objects
This PR uses JSON instead of `toString` to serialize `DataType`s. The latter is not only hard to parse but also flaky in many cases. Since we already write schema information to Parquet metadata in the old style, we have to retain the old `DataType` parser to ensure backward compatibility. The old parser is now renamed to `CaseClassStringParser` and moved into `object DataType`.

JoshRosen davies Please help review the PySpark-related changes, thanks!

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2563 from liancheng/datatype-to-json and squashes the following commits:

fc92eb3 [Cheng Lian] Reverts debugging code, simplifies primitive type JSON representation
438c75f [Cheng Lian] Refactors PySpark DataType JSON SerDe per comments
6b6387b [Cheng Lian] Removes debugging code
6a3ee3a [Cheng Lian] Addresses per review comments
dc158b5 [Cheng Lian] Addresses PEP8 issues
99ab4ee [Cheng Lian] Adds compatibility test case for Parquet type conversion
a983a6c [Cheng Lian] Adds PySpark support
f608c6e [Cheng Lian] De/serializes DataType objects from/to JSON
-rw-r--r--  python/pyspark/sql.py                                                                    | 153
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala  |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala          | 229
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                            |   9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala                  |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala                         |  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala             |  16
7 files changed, 277 insertions, 168 deletions
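For orientation before the diff: a minimal sketch (not part of the patch) of the round trip the change enables, using only the members introduced below on the Catalyst side (json, prettyJson, DataType.fromJson).

    import org.apache.spark.sql.catalyst.types._

    // A small schema exercising primitive, array, and struct types.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("scores", ArrayType(DoubleType, containsNull = true), nullable = true)))

    // Compact JSON replaces the old case-class toString form, e.g.
    // {"type":"struct","fields":[{"name":"id","type":"integer","nullable":false}, ...]}
    val serialized = schema.json
    println(schema.prettyJson)

    // Parsing the JSON back yields an equal DataType -- the property that
    // DataTypeSuite's checkDataTypeJsonRepr verifies for each type below.
    assert(DataType.fromJson(serialized) == schema)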
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3d5a281239..d3d36eb995 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -34,6 +34,7 @@ import decimal
import datetime
import keyword
import warnings
+import json
from array import array
from operator import itemgetter
from itertools import imap
@@ -71,6 +72,18 @@ class DataType(object):
def __ne__(self, other):
return not self.__eq__(other)
+ @classmethod
+ def typeName(cls):
+ return cls.__name__[:-4].lower()
+
+ def jsonValue(self):
+ return self.typeName()
+
+ def json(self):
+ return json.dumps(self.jsonValue(),
+ separators=(',', ':'),
+ sort_keys=True)
+
class PrimitiveTypeSingleton(type):
@@ -214,6 +227,16 @@ class ArrayType(DataType):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return ArrayType(_parse_datatype_json_value(json["elementType"]),
+ json["containsNull"])
+
class MapType(DataType):
@@ -254,6 +277,18 @@ class MapType(DataType):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return MapType(_parse_datatype_json_value(json["keyType"]),
+ _parse_datatype_json_value(json["valueType"]),
+ json["valueContainsNull"])
+
class StructField(DataType):
@@ -292,6 +327,17 @@ class StructField(DataType):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
+ def jsonValue(self):
+ return {"name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructField(json["name"],
+ _parse_datatype_json_value(json["type"]),
+ json["nullable"])
+
class StructType(DataType):
@@ -321,42 +367,30 @@ class StructType(DataType):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "fields": [f.jsonValue() for f in self.fields]}
-def _parse_datatype_list(datatype_list_string):
- """Parses a list of comma separated data types."""
- index = 0
- datatype_list = []
- start = 0
- depth = 0
- while index < len(datatype_list_string):
- if depth == 0 and datatype_list_string[index] == ",":
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- start = index + 1
- elif datatype_list_string[index] == "(":
- depth += 1
- elif datatype_list_string[index] == ")":
- depth -= 1
+ @classmethod
+ def fromJson(cls, json):
+ return StructType([StructField.fromJson(f) for f in json["fields"]])
- index += 1
- # Handle the last data type
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- return datatype_list
+_all_primitive_types = dict((v.typeName(), v)
+ for v in globals().itervalues()
+ if type(v) is PrimitiveTypeSingleton and
+ v.__base__ == PrimitiveType)
-_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
- if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
+_all_complex_types = dict((v.typeName(), v)
+ for v in [ArrayType, MapType, StructType])
-def _parse_datatype_string(datatype_string):
- """Parses the given data type string.
-
+def _parse_datatype_json_string(json_string):
+ """Parses the given data type JSON string.
>>> def check_datatype(datatype):
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
- ... python_datatype = _parse_datatype_string(
- ... scala_datatype.toString())
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
- index = datatype_string.find("(")
- if index == -1:
- # It is a primitive type.
- index = len(datatype_string)
- type_or_field = datatype_string[:index]
- rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
-
- if type_or_field in _all_primitive_types:
- return _all_primitive_types[type_or_field]()
-
- elif type_or_field == "ArrayType":
- last_comma_index = rest_part.rfind(",")
- containsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- containsNull = False
- elementType = _parse_datatype_string(
- rest_part[:last_comma_index].strip())
- return ArrayType(elementType, containsNull)
-
- elif type_or_field == "MapType":
- last_comma_index = rest_part.rfind(",")
- valueContainsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- valueContainsNull = False
- keyType, valueType = _parse_datatype_list(
- rest_part[:last_comma_index].strip())
- return MapType(keyType, valueType, valueContainsNull)
-
- elif type_or_field == "StructField":
- first_comma_index = rest_part.find(",")
- name = rest_part[:first_comma_index].strip()
- last_comma_index = rest_part.rfind(",")
- nullable = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- nullable = False
- dataType = _parse_datatype_string(
- rest_part[first_comma_index + 1:last_comma_index].strip())
- return StructField(name, dataType, nullable)
-
- elif type_or_field == "StructType":
- # rest_part should be in the format like
- # List(StructField(field1,IntegerType,false)).
- field_list_string = rest_part[rest_part.find("(") + 1:-1]
- fields = _parse_datatype_list(field_list_string)
- return StructType(fields)
+ return _parse_datatype_json_value(json.loads(json_string))
+
+
+def _parse_datatype_json_value(json_value):
+ if type(json_value) is unicode and json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ else:
+ return _all_complex_types[json_value["type"]].fromJson(json_value)
# Mapping Python types to Spark SQL DateType
@@ -992,7 +989,7 @@ class SQLContext(object):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
- str(returnType))
+ returnType.json())
def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1128,7 +1125,7 @@ class SQLContext(object):
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def registerRDDAsTable(self, rdd, tableName):
@@ -1218,7 +1215,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1288,7 +1285,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1623,7 +1620,7 @@ class SchemaRDD(RDD):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
+ return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
def schemaString(self):
"""Returns the output schema in the tree format."""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 1eb5571579..1a4ac06c7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
-case object DynamicType extends DataType {
- def simpleString: String = "dynamic"
-}
+case object DynamicType extends DataType
/**
* Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index ac043d4dd8..1d375b8754 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types
import java.sql.Timestamp
-import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
+import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
+import org.json4s.JsonAST.JValue
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
-/**
- * Utility functions for working with DataTypes.
- */
-object DataType extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- "StringType" ^^^ StringType |
- "FloatType" ^^^ FloatType |
- "IntegerType" ^^^ IntegerType |
- "ByteType" ^^^ ByteType |
- "ShortType" ^^^ ShortType |
- "DoubleType" ^^^ DoubleType |
- "LongType" ^^^ LongType |
- "BinaryType" ^^^ BinaryType |
- "BooleanType" ^^^ BooleanType |
- "DecimalType" ^^^ DecimalType |
- "TimestampType" ^^^ TimestampType
-
- protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
- }
- protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
- case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+object DataType {
+ def fromJson(json: String): DataType = parseDataType(parse(json))
+
+ private object JSortedObject {
+ def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
}
+ }
+
+ // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
+ private def parseDataType(json: JValue): DataType = json match {
+ case JString(name) =>
+ PrimitiveType.nameToType(name)
+
+ case JSortedObject(
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
+ ArrayType(parseDataType(t), n)
+
+ case JSortedObject(
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
+ MapType(parseDataType(k), parseDataType(v), n)
+
+ case JSortedObject(
+ ("fields", JArray(fields)),
+ ("type", JString("struct"))) =>
+ StructType(fields.map(parseStructField))
+ }
- protected lazy val structField: Parser[StructField] =
- ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
- case name ~ tpe ~ nullable =>
+ private def parseStructField(json: JValue): StructField = json match {
+ case JSortedObject(
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable)
+ }
+
+ @deprecated("Use DataType.fromJson instead")
+ def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
+
+ private object CaseClassStringParser extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ ( "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DecimalType" ^^^ DecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
+
+ protected lazy val arrayType: Parser[DataType] =
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+ case name ~ tpe ~ nullable =>
StructField(name, tpe, nullable = nullable)
- }
+ }
- protected lazy val boolVal: Parser[Boolean] =
- "true" ^^^ true |
- "false" ^^^ false
+ protected lazy val boolVal: Parser[Boolean] =
+ ( "true" ^^^ true
+ | "false" ^^^ false
+ )
- protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => new StructType(fields)
- }
+ protected lazy val structType: Parser[DataType] =
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+ case fields => new StructType(fields)
+ }
- protected lazy val dataType: Parser[DataType] =
- arrayType |
- mapType |
- structType |
- primitiveType
+ protected lazy val dataType: Parser[DataType] =
+ ( arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
+
+ /**
+ * Parses a string representation of a DataType.
+ *
+ * TODO: Generate parser as pickler...
+ */
+ def apply(asString: String): DataType = parseAll(dataType, asString) match {
+ case Success(result, _) => result
+ case failure: NoSuccess =>
+ throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
+ }
- /**
- * Parses a string representation of a DataType.
- *
- * TODO: Generate parser as pickler...
- */
- def apply(asString: String): DataType = parseAll(dataType, asString) match {
- case Success(result, _) => result
- case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
}
protected[types] def buildFormattedString(
@@ -111,15 +165,19 @@ abstract class DataType {
def isPrimitive: Boolean = false
- def simpleString: String
-}
+ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
+
+ private[sql] def jsonValue: JValue = typeName
-case object NullType extends DataType {
- def simpleString: String = "null"
+ def json: String = compact(render(jsonValue))
+
+ def prettyJson: String = pretty(render(jsonValue))
}
+case object NullType extends DataType
+
object NativeType {
- def all = Seq(
+ val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
@@ -139,6 +197,12 @@ trait PrimitiveType extends DataType {
override def isPrimitive = true
}
+object PrimitiveType {
+ private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
+
+ private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+}
+
abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
@@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "string"
}
case object BinaryType extends NativeType with PrimitiveType {
@@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
- return x.length - y.length
+ x.length - y.length
}
}
- def simpleString: String = "binary"
}
case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "boolean"
}
case object TimestampType extends NativeType {
@@ -187,8 +248,6 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
-
- def simpleString: String = "timestamp"
}
abstract class NumericType extends NativeType with PrimitiveType {
@@ -222,7 +281,6 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "long"
}
case object IntegerType extends IntegralType {
@@ -231,7 +289,6 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "integer"
}
case object ShortType extends IntegralType {
@@ -240,7 +297,6 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "short"
}
case object ByteType extends IntegralType {
@@ -249,7 +305,6 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "byte"
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
@@ -271,7 +326,6 @@ case object DecimalType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[BigDecimal]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = BigDecimalAsIfIntegral
- def simpleString: String = "decimal"
}
case object DoubleType extends FractionalType {
@@ -281,7 +335,6 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
- def simpleString: String = "double"
}
case object FloatType extends FractionalType {
@@ -291,12 +344,12 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
- def simpleString: String = "float"
}
object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
+ def typeName: String = "array"
}
/**
@@ -309,11 +362,14 @@ object ArrayType {
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(
- s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
+ s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
DataType.buildFormattedString(elementType, s"$prefix |", builder)
}
- def simpleString: String = "array"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("elementType" -> elementType.jsonValue) ~
+ ("containsNull" -> containsNull)
}
/**
@@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
case class StructField(name: String, dataType: DataType, nullable: Boolean) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
+ builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+
+ private[sql] def jsonValue: JValue = {
+ ("name" -> name) ~
+ ("type" -> dataType.jsonValue) ~
+ ("nullable" -> nullable)
+ }
}
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+
+ def typeName = "struct"
}
case class StructType(fields: Seq[StructField]) extends DataType {
@@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
- nameToField.get(name).getOrElse(
- throw new IllegalArgumentException(s"Field ${name} does not exist."))
+ nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
}
/**
@@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
- if (!nonExistFields.isEmpty) {
+ if (nonExistFields.nonEmpty) {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
@@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType {
fields.foreach(field => field.buildFormattedString(prefix, builder))
}
- def simpleString: String = "struct"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("fields" -> fields.map(_.jsonValue))
}
object MapType {
@@ -394,6 +459,8 @@ object MapType {
*/
def apply(keyType: DataType, valueType: DataType): MapType =
MapType(keyType: DataType, valueType: DataType, true)
+
+ def simpleName = "map"
}
/**
@@ -407,12 +474,16 @@ case class MapType(
valueType: DataType,
valueContainsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
- builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
- s"(valueContainsNull = ${valueContainsNull})\n")
+ builder.append(s"$prefix-- key: ${keyType.typeName}\n")
+ builder.append(s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
DataType.buildFormattedString(keyType, s"$prefix |", builder)
DataType.buildFormattedString(valueType, s"$prefix |", builder)
}
- def simpleString: String = "map"
+ override private[sql] def jsonValue: JValue =
+ ("type" -> typeName) ~
+ ("keyType" -> keyType.jsonValue) ~
+ ("valueType" -> valueType.jsonValue) ~
+ ("valueContainsNull" -> valueContainsNull)
}
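A small illustration (ours, not from the commit) of what the private JSortedObject extractor above buys: object fields are sorted before matching, so DataType.fromJson accepts keys in any order, as long as the field names match what the Python side emits.

    import org.apache.spark.sql.catalyst.types._

    // Keys are deliberately out of alphabetical order; JSortedObject sorts them
    // before the match in parseDataType, so the input order does not matter.
    val mapJson =
      """{"type":"map","valueType":"string","keyType":"integer","valueContainsNull":false}"""
    assert(DataType.fromJson(mapJson) == MapType(IntegerType, StringType, valueContainsNull = false))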
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7a55c5bf97..35561cac3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
@@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.SparkStrategies
+import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.{Logging, SparkContext}
/**
* :: AlphaComponent ::
@@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
- val parser = org.apache.spark.sql.catalyst.types.DataType
- parser(dataTypeString)
+ DataType.fromJson(dataTypeString)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 2941b97935..e6389cf77a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet
import java.io.IOException
+import scala.util.Try
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
@@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
def convertFromString(string: String): Seq[Attribute] = {
- DataType(string) match {
+ Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
case other => sys.error(s"Can convert $string to row")
}
}
def convertToString(schema: Seq[Attribute]): String = {
- StructType.fromAttributes(schema).toString
+ StructType.fromAttributes(schema).json
}
def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = {
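The Try fallback above is the backward-compatibility path: Parquet metadata written by older releases still carries the case-class string form, while new files carry JSON. A rough sketch of that behaviour (the parseSchema helper name is ours, for illustration):

    import scala.util.Try
    import org.apache.spark.sql.catalyst.types._

    // Mirrors ParquetTypesConverter.convertFromString: try the new JSON form
    // first, then fall back to the legacy case-class string parser.
    def parseSchema(serialized: String): DataType =
      Try(DataType.fromJson(serialized)).getOrElse(DataType.fromCaseClassString(serialized))

    val schema = StructType(Seq(StructField("c1", IntegerType, nullable = false)))
    assert(parseSchema(schema.json) == schema)      // new JSON representation
    assert(parseSchema(schema.toString) == schema)  // legacy case-class string still accepted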
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 8fb59c5830..100ecb45e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.types.DataType
+
class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
@@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite {
struct(Set("b", "d", "e", "f"))
}
}
+
+ def checkDataTypeJsonRepr(dataType: DataType): Unit = {
+ test(s"JSON - $dataType") {
+ assert(DataType.fromJson(dataType.json) === dataType)
+ }
+ }
+
+ checkDataTypeJsonRepr(BooleanType)
+ checkDataTypeJsonRepr(ByteType)
+ checkDataTypeJsonRepr(ShortType)
+ checkDataTypeJsonRepr(IntegerType)
+ checkDataTypeJsonRepr(LongType)
+ checkDataTypeJsonRepr(FloatType)
+ checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(TimestampType)
+ checkDataTypeJsonRepr(StringType)
+ checkDataTypeJsonRepr(BinaryType)
+ checkDataTypeJsonRepr(ArrayType(DoubleType, true))
+ checkDataTypeJsonRepr(ArrayType(StringType, false))
+ checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
+ checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeJsonRepr(
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 07adf73140..25e41ecf28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result3(0)(1) === "the answer")
Utils.deleteRecursively(tmpdir)
}
-
+
test("Querying on empty parquet throws exception (SPARK-3536)") {
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
@@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result1.size === 0)
Utils.deleteRecursively(tmpdir)
}
+
+ test("DataType string parser compatibility") {
+ val schema = StructType(List(
+ StructField("c1", IntegerType, false),
+ StructField("c2", BinaryType, false)))
+
+ val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString)
+ val fromJson = ParquetTypesConverter.convertFromString(schema.json)
+
+ (fromCaseClassString, fromJson).zipped.foreach { (a, b) =>
+ assert(a.name == b.name)
+ assert(a.dataType === b.dataType)
+ }
+ }
}