-rw-r--r--  python/pyspark/sql.py                                                                     | 153
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala   |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala           | 229
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                             |   9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala                   |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala                          |  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala              |  16
7 files changed, 277 insertions(+), 168 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3d5a281239..d3d36eb995 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -34,6 +34,7 @@ import decimal
import datetime
import keyword
import warnings
+import json
from array import array
from operator import itemgetter
from itertools import imap
@@ -71,6 +72,18 @@ class DataType(object):
def __ne__(self, other):
return not self.__eq__(other)
+ @classmethod
+ def typeName(cls):
+ return cls.__name__[:-4].lower()
+
+ def jsonValue(self):
+ return self.typeName()
+
+ def json(self):
+ return json.dumps(self.jsonValue(),
+ separators=(',', ':'),
+ sort_keys=True)
+
class PrimitiveTypeSingleton(type):
@@ -214,6 +227,16 @@ class ArrayType(DataType):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return ArrayType(_parse_datatype_json_value(json["elementType"]),
+ json["containsNull"])
+
class MapType(DataType):
@@ -254,6 +277,18 @@ class MapType(DataType):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return MapType(_parse_datatype_json_value(json["keyType"]),
+ _parse_datatype_json_value(json["valueType"]),
+ json["valueContainsNull"])
+
class StructField(DataType):
@@ -292,6 +327,17 @@ class StructField(DataType):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
+ def jsonValue(self):
+ return {"name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructField(json["name"],
+ _parse_datatype_json_value(json["type"]),
+ json["nullable"])
+
class StructType(DataType):
@@ -321,42 +367,30 @@ class StructType(DataType):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "fields": [f.jsonValue() for f in self.fields]}
-def _parse_datatype_list(datatype_list_string):
- """Parses a list of comma separated data types."""
- index = 0
- datatype_list = []
- start = 0
- depth = 0
- while index < len(datatype_list_string):
- if depth == 0 and datatype_list_string[index] == ",":
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- start = index + 1
- elif datatype_list_string[index] == "(":
- depth += 1
- elif datatype_list_string[index] == ")":
- depth -= 1
+ @classmethod
+ def fromJson(cls, json):
+ return StructType([StructField.fromJson(f) for f in json["fields"]])
- index += 1
- # Handle the last data type
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- return datatype_list
+_all_primitive_types = dict((v.typeName(), v)
+ for v in globals().itervalues()
+ if type(v) is PrimitiveTypeSingleton and
+ v.__base__ == PrimitiveType)
-_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
- if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
+_all_complex_types = dict((v.typeName(), v)
+ for v in [ArrayType, MapType, StructType])
-def _parse_datatype_string(datatype_string):
- """Parses the given data type string.
-
+def _parse_datatype_json_string(json_string):
+ """Parses the given data type JSON string.
>>> def check_datatype(datatype):
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
- ... python_datatype = _parse_datatype_string(
- ... scala_datatype.toString())
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
- index = datatype_string.find("(")
- if index == -1:
- # It is a primitive type.
- index = len(datatype_string)
- type_or_field = datatype_string[:index]
- rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
-
- if type_or_field in _all_primitive_types:
- return _all_primitive_types[type_or_field]()
-
- elif type_or_field == "ArrayType":
- last_comma_index = rest_part.rfind(",")
- containsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- containsNull = False
- elementType = _parse_datatype_string(
- rest_part[:last_comma_index].strip())
- return ArrayType(elementType, containsNull)
-
- elif type_or_field == "MapType":
- last_comma_index = rest_part.rfind(",")
- valueContainsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- valueContainsNull = False
- keyType, valueType = _parse_datatype_list(
- rest_part[:last_comma_index].strip())
- return MapType(keyType, valueType, valueContainsNull)
-
- elif type_or_field == "StructField":
- first_comma_index = rest_part.find(",")
- name = rest_part[:first_comma_index].strip()
- last_comma_index = rest_part.rfind(",")
- nullable = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- nullable = False
- dataType = _parse_datatype_string(
- rest_part[first_comma_index + 1:last_comma_index].strip())
- return StructField(name, dataType, nullable)
-
- elif type_or_field == "StructType":
- # rest_part should be in the format like
- # List(StructField(field1,IntegerType,false)).
- field_list_string = rest_part[rest_part.find("(") + 1:-1]
- fields = _parse_datatype_list(field_list_string)
- return StructType(fields)
+ return _parse_datatype_json_value(json.loads(json_string))
+
+
+def _parse_datatype_json_value(json_value):
+ if type(json_value) is unicode and json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ else:
+ return _all_complex_types[json_value["type"]].fromJson(json_value)
# Mapping Python types to Spark SQL DateType
@@ -992,7 +989,7 @@ class SQLContext(object):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
- str(returnType))
+ returnType.json())
def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1128,7 +1125,7 @@ class SQLContext(object):
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def registerRDDAsTable(self, rdd, tableName):
@@ -1218,7 +1215,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1288,7 +1285,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1623,7 +1620,7 @@ class SchemaRDD(RDD):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
+ return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
def schemaString(self):
"""Returns the output schema in the tree format."""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 1eb5571579..1a4ac06c7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
-case object DynamicType extends DataType {
- def simpleString: String = "dynamic"
-}
+case object DynamicType extends DataType
/**
* Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index ac043d4dd8..1d375b8754 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types
import java.sql.Timestamp
-import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
+import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
+import org.json4s.JsonAST.JValue
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
-/**
- * Utility functions for working with DataTypes.
- */
-object DataType extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- "StringType" ^^^ StringType |
- "FloatType" ^^^ FloatType |
- "IntegerType" ^^^ IntegerType |
- "ByteType" ^^^ ByteType |
- "ShortType" ^^^ ShortType |
- "DoubleType" ^^^ DoubleType |
- "LongType" ^^^ LongType |
- "BinaryType" ^^^ BinaryType |
- "BooleanType" ^^^ BooleanType |
- "DecimalType" ^^^ DecimalType |
- "TimestampType" ^^^ TimestampType
-
- protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
- }
- protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
- case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+object DataType {
+ def fromJson(json: String): DataType = parseDataType(parse(json))
+
+ private object JSortedObject {
+ def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
}
+ }
+
+ // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
+ private def parseDataType(json: JValue): DataType = json match {
+ case JString(name) =>
+ PrimitiveType.nameToType(name)
+
+ case JSortedObject(
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
+ ArrayType(parseDataType(t), n)
+
+ case JSortedObject(
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
+ MapType(parseDataType(k), parseDataType(v), n)
+
+ case JSortedObject(
+ ("fields", JArray(fields)),
+ ("type", JString("struct"))) =>
+ StructType(fields.map(parseStructField))
+ }
- protected lazy val structField: Parser[StructField] =
- ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
- case name ~ tpe ~ nullable =>
+ private def parseStructField(json: JValue): StructField = json match {
+ case JSortedObject(
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable)
+ }
+
+ @deprecated("Use DataType.fromJson instead")
+ def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
+
+ private object CaseClassStringParser extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ ( "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DecimalType" ^^^ DecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
+
+ protected lazy val arrayType: Parser[DataType] =
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+ case name ~ tpe ~ nullable =>
StructField(name, tpe, nullable = nullable)
- }
+ }
- protected lazy val boolVal: Parser[Boolean] =
- "true" ^^^ true |
- "false" ^^^ false
+ protected lazy val boolVal: Parser[Boolean] =
+ ( "true" ^^^ true
+ | "false" ^^^ false
+ )
- protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => new StructType(fields)
- }
+ protected lazy val structType: Parser[DataType] =
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+ case fields => new StructType(fields)
+ }
- protected lazy val dataType: Parser[DataType] =
- arrayType |
- mapType |
- structType |
- primitiveType
+ protected lazy val dataType: Parser[DataType] =
+ ( arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
+
+ /**
+ * Parses a string representation of a DataType.
+ *
+ * TODO: Generate parser as pickler...
+ */
+ def apply(asString: String): DataType = parseAll(dataType, asString) match {
+ case Success(result, _) => result
+ case failure: NoSuccess =>
+ throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
+ }
- /**
- * Parses a string representation of a DataType.
- *
- * TODO: Generate parser as pickler...
- */
- def apply(asString: String): DataType = parseAll(dataType, asString) match {
- case Success(result, _) => result
- case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
}
protected[types] def buildFormattedString(
@@ -111,15 +165,19 @@ abstract class DataType {
def isPrimitive: Boolean = false
- def simpleString: String
-}
+ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
+
+ private[sql] def jsonValue: JValue = typeName
-case object NullType extends DataType {
- def simpleString: String = "null"
+ def json: String = compact(render(jsonValue))
+
+ def prettyJson: String = pretty(render(jsonValue))
}
+case object NullType extends DataType
+
object NativeType {
- def all = Seq(
+ val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
@@ -139,6 +197,12 @@ trait PrimitiveType extends DataType {
override def isPrimitive = true
}
+object PrimitiveType {
+ private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
+
+ private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+}
+
abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
@@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "string"
}
case object BinaryType extends NativeType with PrimitiveType {
@@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
- return x.length - y.length
+ x.length - y.length
}
}
- def simpleString: String = "binary"
}
case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "boolean"
}
case object TimestampType extends NativeType {
@@ -187,8 +248,6 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
-
- def simpleString: String = "timestamp"
}
abstract class NumericType extends NativeType with PrimitiveType {
@@ -222,7 +281,6 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "long"
}
case object IntegerType extends IntegralType {
@@ -231,7 +289,6 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "integer"
}
case object ShortType extends IntegralType {
@@ -240,7 +297,6 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "short"
}
case object ByteType extends IntegralType {
@@ -249,7 +305,6 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "byte"
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
@@ -271,7 +326,6 @@ case object DecimalType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[BigDecimal]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = BigDecimalAsIfIntegral
- def simpleString: String = "decimal"
}
case object DoubleType extends FractionalType {
@@ -281,7 +335,6 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
- def simpleString: String = "double"
}
case object FloatType extends FractionalType {
@@ -291,12 +344,12 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
- def simpleString: String = "float"
}
object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
+ def typeName: String = "array"
}
/**
@@ -309,11 +362,14 @@ object ArrayType {
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(
- s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
+ s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
DataType.buildFormattedString(elementType, s"$prefix |", builder)
}
- def simpleString: String = "array"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("elementType" -> elementType.jsonValue) ~
+ ("containsNull" -> containsNull)
}
/**
@@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
case class StructField(name: String, dataType: DataType, nullable: Boolean) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
+ builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+
+ private[sql] def jsonValue: JValue = {
+ ("name" -> name) ~
+ ("type" -> dataType.jsonValue) ~
+ ("nullable" -> nullable)
+ }
}
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+
+ def typeName = "struct"
}
case class StructType(fields: Seq[StructField]) extends DataType {
@@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
- nameToField.get(name).getOrElse(
- throw new IllegalArgumentException(s"Field ${name} does not exist."))
+ nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
}
/**
@@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
- if (!nonExistFields.isEmpty) {
+ if (nonExistFields.nonEmpty) {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
@@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType {
fields.foreach(field => field.buildFormattedString(prefix, builder))
}
- def simpleString: String = "struct"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("fields" -> fields.map(_.jsonValue))
}
object MapType {
@@ -394,6 +459,8 @@ object MapType {
*/
def apply(keyType: DataType, valueType: DataType): MapType =
MapType(keyType: DataType, valueType: DataType, true)
+
+ def simpleName = "map"
}
/**
@@ -407,12 +474,16 @@ case class MapType(
valueType: DataType,
valueContainsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
- builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
- s"(valueContainsNull = ${valueContainsNull})\n")
+ builder.append(s"$prefix-- key: ${keyType.typeName}\n")
+ builder.append(s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
DataType.buildFormattedString(keyType, s"$prefix |", builder)
DataType.buildFormattedString(valueType, s"$prefix |", builder)
}
- def simpleString: String = "map"
+ override private[sql] def jsonValue: JValue =
+ ("type" -> typeName) ~
+ ("keyType" -> keyType.jsonValue) ~
+ ("valueType" -> valueType.jsonValue) ~
+ ("valueContainsNull" -> valueContainsNull)
}
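
The same round trip is available on the Scala side through json, prettyJson, and DataType.fromJson. A small sketch, assuming the catalyst types package is on the classpath; note that the Scala renderer keeps construction order ("type" first) while the Python side sorts keys alphabetically, and the JSortedObject extractor lets the parser accept both forms:

import org.apache.spark.sql.catalyst.types._

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("scores", ArrayType(DoubleType), nullable = true)))

// Compact single-line form, e.g. {"type":"struct","fields":[...]}
val compactForm: String = schema.json
// Indented form, convenient for debugging
val prettyForm: String = schema.prettyJson

// Parsing the rendered JSON yields a structurally equal DataType.
assert(DataType.fromJson(compactForm) == schema)
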
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7a55c5bf97..35561cac3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
@@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.SparkStrategies
+import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.{Logging, SparkContext}
/**
* :: AlphaComponent ::
@@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
- val parser = org.apache.spark.sql.catalyst.types.DataType
- parser(dataTypeString)
+ DataType.fromJson(dataTypeString)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 2941b97935..e6389cf77a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet
import java.io.IOException
+import scala.util.Try
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
@@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
def convertFromString(string: String): Seq[Attribute] = {
- DataType(string) match {
+ Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
case other => sys.error(s"Can convert $string to row")
}
}
def convertToString(schema: Seq[Attribute]): String = {
- StructType.fromAttributes(schema).toString
+ StructType.fromAttributes(schema).json
}
def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = {
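
The Try/getOrElse chain in convertFromString keeps old Parquet metadata readable: convertToString now stores schema.json, but files written before this change carry the StructType(...).toString form. A minimal sketch of that fallback with a hypothetical helper name (parseSchemaString is not part of the patch):

import scala.util.Try
import org.apache.spark.sql.catalyst.types.DataType

// Try the new JSON representation first, then fall back to the
// deprecated case-class-string parser for metadata written by older releases.
def parseSchemaString(schemaString: String): DataType =
  Try(DataType.fromJson(schemaString)).getOrElse(DataType.fromCaseClassString(schemaString))
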
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 8fb59c5830..100ecb45e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.types.DataType
+
class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
@@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite {
struct(Set("b", "d", "e", "f"))
}
}
+
+ def checkDataTypeJsonRepr(dataType: DataType): Unit = {
+ test(s"JSON - $dataType") {
+ assert(DataType.fromJson(dataType.json) === dataType)
+ }
+ }
+
+ checkDataTypeJsonRepr(BooleanType)
+ checkDataTypeJsonRepr(ByteType)
+ checkDataTypeJsonRepr(ShortType)
+ checkDataTypeJsonRepr(IntegerType)
+ checkDataTypeJsonRepr(LongType)
+ checkDataTypeJsonRepr(FloatType)
+ checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(TimestampType)
+ checkDataTypeJsonRepr(StringType)
+ checkDataTypeJsonRepr(BinaryType)
+ checkDataTypeJsonRepr(ArrayType(DoubleType, true))
+ checkDataTypeJsonRepr(ArrayType(StringType, false))
+ checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
+ checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeJsonRepr(
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 07adf73140..25e41ecf28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result3(0)(1) === "the answer")
Utils.deleteRecursively(tmpdir)
}
-
+
test("Querying on empty parquet throws exception (SPARK-3536)") {
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
@@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result1.size === 0)
Utils.deleteRecursively(tmpdir)
}
+
+ test("DataType string parser compatibility") {
+ val schema = StructType(List(
+ StructField("c1", IntegerType, false),
+ StructField("c2", BinaryType, false)))
+
+ val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString)
+ val fromJson = ParquetTypesConverter.convertFromString(schema.json)
+
+ (fromCaseClassString, fromJson).zipped.foreach { (a, b) =>
+ assert(a.name == b.name)
+ assert(a.dataType === b.dataType)
+ }
+ }
}