-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 3
-rw-r--r--  project/SparkBuild.scala | 2
-rw-r--r--  python/pyspark/sql.py | 567
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 45
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 268
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 66
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java | 68
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala) | 19
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java | 190
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java | 78
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java | 76
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java | 59
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java | 27
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 230
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 65
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 118
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package-info.java (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java) | 0
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package.scala | 409
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala | 110
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java | 166
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java | 170
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java | 150
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala | 58
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 64
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala | 81
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 198
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 8
61 files changed, 3442 insertions, 386 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d8453fb18..f551a59ee3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging {
}
/**
- * Convert an RDD of serialized Python dictionaries to Scala Maps
+ * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
+ * It is only used by pyspark.sql.
* TODO: Support more Python types.
*/
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 490fac3cc3..e2dab0f9f7 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -312,7 +312,7 @@ object Unidoc {
"mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration",
"mllib.tree.impurity", "mllib.tree.model", "mllib.util"
),
- "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"),
+ "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
"-noqualifier", "java.lang"
)
)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a6b3277db3..13f0ed4e35 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,7 +20,451 @@ from pyspark.serializers import BatchedSerializer, PickleSerializer
from py4j.protocol import Py4JError
-__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+__all__ = [
+ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
+ "ShortType", "ArrayType", "MapType", "StructField", "StructType",
+ "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class PrimitiveTypeSingleton(type):
+ _instances = {}
+
+ def __call__(cls):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
+ return cls._instances[cls]
+
+
+class StringType(object):
+ """Spark SQL StringType
+
+ The data type representing string values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "StringType"
+
+
+class BinaryType(object):
+ """Spark SQL BinaryType
+
+ The data type representing bytearray values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BinaryType"
+
+
+class BooleanType(object):
+ """Spark SQL BooleanType
+
+ The data type representing bool values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BooleanType"
+
+
+class TimestampType(object):
+ """Spark SQL TimestampType
+
+ The data type representing datetime.datetime values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "TimestampType"
+
+
+class DecimalType(object):
+ """Spark SQL DecimalType
+
+ The data type representing decimal.Decimal values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DecimalType"
+
+
+class DoubleType(object):
+ """Spark SQL DoubleType
+
+ The data type representing float values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DoubleType"
+
+
+class FloatType(object):
+ """Spark SQL FloatType
+
+ The data type representing single precision floating-point values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "FloatType"
+
+
+class ByteType(object):
+ """Spark SQL ByteType
+
+ The data type representing int values with 1 signed byte.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ByteType"
+
+
+class IntegerType(object):
+ """Spark SQL IntegerType
+
+ The data type representing int values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "IntegerType"
+
+
+class LongType(object):
+ """Spark SQL LongType
+
+ The data type representing long values. If any value is beyond the range of
+ [-9223372036854775808, 9223372036854775807], please use DecimalType.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "LongType"
+
+
+class ShortType(object):
+ """Spark SQL ShortType
+
+ The data type representing int values with 2 signed bytes.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ShortType"
+
+
+class ArrayType(object):
+ """Spark SQL ArrayType
+
+ The data type representing list values.
+ An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool).
+ The field of elementType is used to specify the type of array elements.
+ The field of containsNull is used to specify if the array has None values.
+
+ """
+ def __init__(self, elementType, containsNull=False):
+ """Creates an ArrayType
+
+ :param elementType: the data type of elements.
+ :param containsNull: indicates whether the list contains None values.
+
+ >>> ArrayType(StringType) == ArrayType(StringType, False)
+ True
+ >>> ArrayType(StringType, True) == ArrayType(StringType)
+ False
+ """
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self):
+ return "ArrayType(" + self.elementType.__repr__() + "," + \
+ str(self.containsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.elementType == other.elementType and
+ self.containsNull == other.containsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class MapType(object):
+ """Spark SQL MapType
+
+ The data type representing dict values.
+ A MapType object comprises three fields,
+ keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool).
+ The field of keyType is used to specify the type of keys in the map.
+ The field of valueType is used to specify the type of values in the map.
+ The field of valueContainsNull is used to specify if values of this map can have None values.
+ For values of a MapType column, keys are not allowed to have None values.
+
+ """
+ def __init__(self, keyType, valueType, valueContainsNull=True):
+ """Creates a MapType
+ :param keyType: the data type of keys.
+ :param valueType: the data type of values.
+ :param valueContainsNull: indicates whether values can contain None values.
+
+ >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
+ True
+ >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType)
+ False
+ """
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self):
+ return "MapType(" + self.keyType.__repr__() + "," + \
+ self.valueType.__repr__() + "," + \
+ str(self.valueContainsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.keyType == other.keyType and
+ self.valueType == other.valueType and
+ self.valueContainsNull == other.valueContainsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructField(object):
+ """Spark SQL StructField
+
+ Represents a field in a StructType.
+ A StructField object comprises three fields, name (a string), dataType (a DataType),
+ and nullable (a bool). The field of name is the name of a StructField. The field of
+ dataType specifies the data type of a StructField.
+ The field of nullable specifies if values of a StructField can contain None values.
+
+ """
+ def __init__(self, name, dataType, nullable):
+ """Creates a StructField
+ :param name: the name of this field.
+ :param dataType: the data type of this field.
+ :param nullable: indicates whether values of this field can be null.
+
+ >>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
+ True
+ >>> StructField("f1", StringType, True) == StructField("f2", StringType, True)
+ False
+ """
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+
+ def __repr__(self):
+ return "StructField(" + self.name + "," + \
+ self.dataType.__repr__() + "," + \
+ str(self.nullable).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.name == other.name and
+ self.dataType == other.dataType and
+ self.nullable == other.nullable)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructType(object):
+ """Spark SQL StructType
+
+ The data type representing namedtuple values.
+ A StructType object comprises a list of L{StructField}s.
+
+ """
+ def __init__(self, fields):
+ """Creates a StructType
+
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True),
+ ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 == struct2
+ False
+ """
+ self.fields = fields
+
+ def __repr__(self):
+ return "StructType(List(" + \
+ ",".join([field.__repr__() for field in self.fields]) + "))"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.fields == other.fields)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+def _parse_datatype_list(datatype_list_string):
+ """Parses a list of comma separated data types."""
+ index = 0
+ datatype_list = []
+ start = 0
+ depth = 0
+ while index < len(datatype_list_string):
+ if depth == 0 and datatype_list_string[index] == ",":
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ start = index + 1
+ elif datatype_list_string[index] == "(":
+ depth += 1
+ elif datatype_list_string[index] == ")":
+ depth -= 1
+
+ index += 1
+
+ # Handle the last data type
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ return datatype_list
+
+
+def _parse_datatype_string(datatype_string):
+ """Parses the given data type string.
+
+ >>> def check_datatype(datatype):
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
+ ... python_datatype = _parse_datatype_string(scala_datatype.toString())
+ ... return datatype == python_datatype
+ >>> check_datatype(StringType())
+ True
+ >>> check_datatype(BinaryType())
+ True
+ >>> check_datatype(BooleanType())
+ True
+ >>> check_datatype(TimestampType())
+ True
+ >>> check_datatype(DecimalType())
+ True
+ >>> check_datatype(DoubleType())
+ True
+ >>> check_datatype(FloatType())
+ True
+ >>> check_datatype(ByteType())
+ True
+ >>> check_datatype(IntegerType())
+ True
+ >>> check_datatype(LongType())
+ True
+ >>> check_datatype(ShortType())
+ True
+ >>> # Simple ArrayType.
+ >>> simple_arraytype = ArrayType(StringType(), True)
+ >>> check_datatype(simple_arraytype)
+ True
+ >>> # Simple MapType.
+ >>> simple_maptype = MapType(StringType(), LongType())
+ >>> check_datatype(simple_maptype)
+ True
+ >>> # Simple StructType.
+ >>> simple_structtype = StructType([
+ ... StructField("a", DecimalType(), False),
+ ... StructField("b", BooleanType(), True),
+ ... StructField("c", LongType(), True),
+ ... StructField("d", BinaryType(), False)])
+ >>> check_datatype(simple_structtype)
+ True
+ >>> # Complex StructType.
+ >>> complex_structtype = StructType([
+ ... StructField("simpleArray", simple_arraytype, True),
+ ... StructField("simpleMap", simple_maptype, True),
+ ... StructField("simpleStruct", simple_structtype, True),
+ ... StructField("boolean", BooleanType(), False)])
+ >>> check_datatype(complex_structtype)
+ True
+ >>> # Complex ArrayType.
+ >>> complex_arraytype = ArrayType(complex_structtype, True)
+ >>> check_datatype(complex_arraytype)
+ True
+ >>> # Complex MapType.
+ >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False)
+ >>> check_datatype(complex_maptype)
+ True
+ """
+ left_bracket_index = datatype_string.find("(")
+ if left_bracket_index == -1:
+ # It is a primitive type.
+ left_bracket_index = len(datatype_string)
+ type_or_field = datatype_string[:left_bracket_index]
+ rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
+ if type_or_field == "StringType":
+ return StringType()
+ elif type_or_field == "BinaryType":
+ return BinaryType()
+ elif type_or_field == "BooleanType":
+ return BooleanType()
+ elif type_or_field == "TimestampType":
+ return TimestampType()
+ elif type_or_field == "DecimalType":
+ return DecimalType()
+ elif type_or_field == "DoubleType":
+ return DoubleType()
+ elif type_or_field == "FloatType":
+ return FloatType()
+ elif type_or_field == "ByteType":
+ return ByteType()
+ elif type_or_field == "IntegerType":
+ return IntegerType()
+ elif type_or_field == "LongType":
+ return LongType()
+ elif type_or_field == "ShortType":
+ return ShortType()
+ elif type_or_field == "ArrayType":
+ last_comma_index = rest_part.rfind(",")
+ containsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ containsNull = False
+ elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
+ return ArrayType(elementType, containsNull)
+ elif type_or_field == "MapType":
+ last_comma_index = rest_part.rfind(",")
+ valueContainsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ valueContainsNull = False
+ keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip())
+ return MapType(keyType, valueType, valueContainsNull)
+ elif type_or_field == "StructField":
+ first_comma_index = rest_part.find(",")
+ name = rest_part[:first_comma_index].strip()
+ last_comma_index = rest_part.rfind(",")
+ nullable = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ nullable = False
+ dataType = _parse_datatype_string(
+ rest_part[first_comma_index+1:last_comma_index].strip())
+ return StructField(name, dataType, nullable)
+ elif type_or_field == "StructType":
+ # rest_part should be in the format like
+ # List(StructField(field1,IntegerType,false)).
+ field_list_string = rest_part[rest_part.find("(")+1:-1]
+ fields = _parse_datatype_list(field_list_string)
+ return StructType(fields)
class SQLContext:
@@ -109,6 +553,40 @@ class SQLContext:
srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
return SchemaRDD(srdd, self)
+ def applySchema(self, rdd, schema):
+ """Applies the given schema to the given RDD of L{dict}s.
+
+ >>> schema = StructType([StructField("field1", IntegerType(), False),
+ ... StructField("field2", StringType(), False)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.sql("SELECT * from table1")
+ >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+ ... {"field1" : 3, "field2": "row3"}]
+ True
+ >>> from datetime import datetime
+ >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
+ ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
+ ... "list": [1, 2, 3]}])
+ >>> schema = StructType([
+ ... StructField("byte", ByteType(), False),
+ ... StructField("short", ShortType(), False),
+ ... StructField("float", FloatType(), False),
+ ... StructField("time", TimestampType(), False),
+ ... StructField("map", MapType(StringType(), IntegerType(), False), False),
+ ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
+ ... StructField("list", ArrayType(ByteType(), False), False),
+ ... StructField("null", DoubleType(), True)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema).map(
+ ... lambda x: (
+ ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
+ >>> srdd.collect()[0]
+ (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ """
+ jrdd = self._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
+ return SchemaRDD(srdd, self)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -139,10 +617,11 @@ class SQLContext:
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)
- def jsonFile(self, path):
- """Loads a text file storing one JSON object per line,
- returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonFile(self, path, schema=None):
+ """Loads a text file storing one JSON object per line as a L{SchemaRDD}.
+
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -151,8 +630,8 @@ class SQLContext:
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -160,16 +639,45 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
- jschema_rdd = self._ssql_ctx.jsonFile(path)
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonFile(path)
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(jschema_rdd, self)
- def jsonRDD(self, rdd):
- """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonRDD(self, rdd, schema=None):
+ """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
- >>> srdd = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
+
+ >>> srdd1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -177,6 +685,29 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
def func(split, iterator):
for x in iterator:
@@ -186,7 +717,11 @@ class SQLContext:
keyed = PipelinedRDD(rdd, func)
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(jschema_rdd, self)
def sql(self, sqlQuery):
@@ -389,6 +924,10 @@ class SchemaRDD(RDD):
"""Creates a new table with the contents of this SchemaRDD."""
self._jschema_rdd.saveAsTable(tableName)
+ def schema(self):
+ """Returns the schema of this SchemaRDD (represented by a L{StructType})."""
+ return _parse_datatype_string(self._jschema_rdd.schema().toString())
+
def schemaString(self):
"""Returns the output schema in the tree format."""
return self._jschema_rdd.schemaString()
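The Python DataType classes above round-trip through their __repr__ strings: applySchema, jsonFile and jsonRDD pass schema.__repr__() to the JVM, and SchemaRDD.schema() parses the JVM-side toString back with _parse_datatype_string. Below is a minimal Scala-side sketch of the schema used in the applySchema doctest, assuming spark-catalyst is on the classpath (e.g. a spark-shell session); the printed tree follows the treeString format added in dataTypes.scala later in this diff.

    import org.apache.spark.sql.catalyst.types._

    // Scala-side equivalent of the Python schema
    //   StructType([StructField("field1", IntegerType(), False),
    //               StructField("field2", StringType(), False)])
    val schema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", StringType, nullable = false)))

    // The case-class toString (e.g. "StructType(List(StructField(field1,IntegerType,false), ...))")
    // is the string format that pyspark's _parse_datatype_string consumes.
    println(schema)

    println(schema.treeString)
    // prints something like:
    // root
    //  |-- field1: integer (nullable = false)
    //  |-- field2: string (nullable = false)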
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5a55be1e51..0d26b52a84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -85,6 +85,26 @@ object ScalaReflection {
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
+ def typeOfObject: PartialFunction[Any, DataType] = {
+ // The data type can be determined without ambiguity.
+ case obj: BooleanType.JvmType => BooleanType
+ case obj: BinaryType.JvmType => BinaryType
+ case obj: StringType.JvmType => StringType
+ case obj: ByteType.JvmType => ByteType
+ case obj: ShortType.JvmType => ShortType
+ case obj: IntegerType.JvmType => IntegerType
+ case obj: LongType.JvmType => LongType
+ case obj: FloatType.JvmType => FloatType
+ case obj: DoubleType.JvmType => DoubleType
+ case obj: DecimalType.JvmType => DecimalType
+ case obj: TimestampType.JvmType => TimestampType
+ case null => NullType
+ // For other cases, there is no obvious mapping from the type of the given object to a
+ // Catalyst data type. A user should provide his/her specific rules
+ // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of
+ // objects and then compose the user-defined PartialFunction with this one.
+ }
+
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
/**
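As the comment above suggests, typeOfObject is meant to be composed with user-defined rules; the new ScalaReflectionSuite later in this diff exercises the same idea. A minimal sketch, assuming spark-catalyst is on the classpath (e.g. a spark-shell session); the BigInteger-to-DecimalType rule is purely illustrative and not part of this patch.

    import org.apache.spark.sql.catalyst.ScalaReflection.typeOfObject
    import org.apache.spark.sql.catalyst.types._

    // Compose the built-in rules with an extra, user-supplied rule.
    val withBigInteger: PartialFunction[Any, DataType] = typeOfObject orElse {
      case _: java.math.BigInteger => DecimalType  // illustrative mapping
    }

    withBigInteger("a string")                        // StringType (built-in rule)
    withBigInteger(new java.math.BigInteger("12345")) // DecimalType (user rule)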
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index a3ebec8082..f38f99569f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -17,14 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
-import org.apache.spark.sql.Logging
-
/**
* A bound reference points to a specific slot in the input tuple, allowing the actual value
* to be retrieved more efficiently. However, since operations like column pruning can change
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 7470cb861b..c9a63e201e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -32,6 +32,16 @@ object Row {
* }}}
*/
def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+
+ /**
+ * This method can be used to construct a [[Row]] with the given values.
+ */
+ def apply(values: Any*): Row = new GenericRow(values.toArray)
+
+ /**
+ * This method can be used to construct a [[Row]] from a [[Seq]] of values.
+ */
+ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray)
}
/**
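A minimal usage sketch of the two constructors added above, assuming spark-catalyst is on the classpath (e.g. a spark-shell session); values are illustrative.

    import org.apache.spark.sql.catalyst.expressions.Row

    val r1 = Row(1, "row1", true)               // positional construction via the new apply
    val r2 = Row.fromSeq(Seq(1, "row1", true))  // construction from an existing Seq

    // unapplySeq (already defined above) allows the usual pattern matching:
    val Row(id, name, flag) = r1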
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index e787c59e75..eb8898900d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -21,8 +21,16 @@ import scala.language.dynamics
import org.apache.spark.sql.catalyst.types.DataType
-case object DynamicType extends DataType
+/**
+ * The data type representing [[DynamicRow]] values.
+ */
+case object DynamicType extends DataType {
+ def simpleString: String = "dynamic"
+}
+/**
+ * Wrap a [[Row]] as a [[DynamicRow]].
+ */
case class WrapDynamic(children: Seq[Attribute]) extends Expression {
type EvaluatedType = DynamicRow
@@ -37,6 +45,11 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
}
}
+/**
+ * DynamicRows use Scala's Dynamic trait to emulate an ORM in a dynamically typed language.
+ * Since the type of the column is not known at compile time, all attributes are converted to
+ * strings before being passed to the function.
+ */
class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
extends GenericRow(values) with Dynamic {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 0acb29012f..72add5e20e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -31,8 +31,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def foldable = child.foldable && ordinal.foldable
override def references = children.flatMap(_.references).toSet
def dataType = child.dataType match {
- case ArrayType(dt) => dt
- case MapType(_, vt) => vt
+ case ArrayType(dt, _) => dt
+ case MapType(_, vt, _) => vt
}
override lazy val resolved =
childrenResolved &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index dd78614754..422839dab7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -84,8 +84,8 @@ case class Explode(attributeNames: Seq[String], child: Expression)
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
private lazy val elementTypes = child.dataType match {
- case ArrayType(et) => et :: Nil
- case MapType(kt,vt) => kt :: vt :: Nil
+ case ArrayType(et, _) => et :: Nil
+ case MapType(kt,vt, _) => kt :: vt :: Nil
}
// TODO: Move this pattern into Generator.
@@ -102,10 +102,10 @@ case class Explode(attributeNames: Seq[String], child: Expression)
override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
- case ArrayType(_) =>
+ case ArrayType(_, _) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
- case MapType(_, _) =>
+ case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]]
if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
index 3b3e206055..ca9642954e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
@@ -24,4 +24,6 @@ package object catalyst {
* 2.10.* builds. See SI-6240 for more details.
*/
protected[catalyst] object ScalaReflectionLock
+
+ protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
index 67833664b3..781ba489b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.planning
-import org.apache.spark.sql.Logging
+import org.apache.spark.sql.catalyst.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 418f8686bf..bc763a4e06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.planning
import scala.annotation.tailrec
-import org.apache.spark.sql.Logging
-
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.Logging
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 7b82e19b2e..0988b0c6d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -125,51 +125,10 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}.toSeq
}
- protected def generateSchemaString(schema: Seq[Attribute]): String = {
- val builder = new StringBuilder
- builder.append("root\n")
- val prefix = " |"
- schema.foreach { attribute =>
- val name = attribute.name
- val dataType = attribute.dataType
- dataType match {
- case fields: StructType =>
- builder.append(s"$prefix-- $name: $StructType\n")
- generateSchemaString(fields, s"$prefix |", builder)
- case ArrayType(fields: StructType) =>
- builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n")
- generateSchemaString(fields, s"$prefix |", builder)
- case ArrayType(elementType: DataType) =>
- builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n")
- case _ => builder.append(s"$prefix-- $name: $dataType\n")
- }
- }
-
- builder.toString()
- }
-
- protected def generateSchemaString(
- schema: StructType,
- prefix: String,
- builder: StringBuilder): StringBuilder = {
- schema.fields.foreach {
- case StructField(name, fields: StructType, _) =>
- builder.append(s"$prefix-- $name: $StructType\n")
- generateSchemaString(fields, s"$prefix |", builder)
- case StructField(name, ArrayType(fields: StructType), _) =>
- builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n")
- generateSchemaString(fields, s"$prefix |", builder)
- case StructField(name, ArrayType(elementType: DataType), _) =>
- builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n")
- case StructField(name, fieldType: DataType, _) =>
- builder.append(s"$prefix-- $name: $fieldType\n")
- }
-
- builder
- }
+ def schema: StructType = StructType.fromAttributes(output)
/** Returns the output schema in the tree format. */
- def schemaString: String = generateSchemaString(output)
+ def schemaString: String = schema.treeString
/** Prints out the schema in the tree format */
def printSchema(): Unit = println(schemaString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1537de259c..3cb407217c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
case StructType(fields) =>
StructType(fields.map(f =>
StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
- case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType))
+ case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
case otherType => otherType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala
index 1076537bc7..f8960b3fe7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.rules
-import org.apache.spark.sql.Logging
+import org.apache.spark.sql.catalyst.Logging
import org.apache.spark.sql.catalyst.trees.TreeNode
abstract class Rule[TreeType <: TreeNode[_]] extends Logging {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e300bdbece..6aa407c836 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -15,10 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.sql
-package catalyst
-package rules
+package org.apache.spark.sql.catalyst.rules
+import org.apache.spark.sql.catalyst.Logging
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index d159ecdd5d..9a28d035a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst
-import org.apache.spark.sql.Logger
-
/**
* A library for easily manipulating trees of operators. Operators that extend TreeNode are
* granted the following interface:
@@ -35,5 +33,6 @@ import org.apache.spark.sql.Logger
*/
package object trees {
// Since we want tree nodes to be lightweight, we create one logger for all treenode instances.
- protected val logger = Logger("catalyst.trees")
+ protected val logger =
+ com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees"))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 71808f76d6..b52ee6d337 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -45,11 +45,13 @@ object DataType extends RegexParsers {
"TimestampType" ^^^ TimestampType
protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ {
- case t1 ~ _ ~ t2 => MapType(t1, t2)
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
}
protected lazy val structField: Parser[StructField] =
@@ -82,6 +84,21 @@ object DataType extends RegexParsers {
case Success(result, _) => result
case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
}
+
+ protected[types] def buildFormattedString(
+ dataType: DataType,
+ prefix: String,
+ builder: StringBuilder): Unit = {
+ dataType match {
+ case array: ArrayType =>
+ array.buildFormattedString(prefix, builder)
+ case struct: StructType =>
+ struct.buildFormattedString(prefix, builder)
+ case map: MapType =>
+ map.buildFormattedString(prefix, builder)
+ case _ =>
+ }
+ }
}
abstract class DataType {
@@ -92,9 +109,13 @@ abstract class DataType {
}
def isPrimitive: Boolean = false
+
+ def simpleString: String
}
-case object NullType extends DataType
+case object NullType extends DataType {
+ def simpleString: String = "null"
+}
object NativeType {
def all = Seq(
@@ -108,40 +129,45 @@ trait PrimitiveType extends DataType {
}
abstract class NativeType extends DataType {
- type JvmType
- @transient val tag: TypeTag[JvmType]
- val ordering: Ordering[JvmType]
+ private[sql] type JvmType
+ @transient private[sql] val tag: TypeTag[JvmType]
+ private[sql] val ordering: Ordering[JvmType]
- @transient val classTag = ScalaReflectionLock.synchronized {
+ @transient private[sql] val classTag = ScalaReflectionLock.synchronized {
val mirror = runtimeMirror(Utils.getSparkClassLoader)
ClassTag[JvmType](mirror.runtimeClass(tag.tpe))
}
}
case object StringType extends NativeType with PrimitiveType {
- type JvmType = String
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = String
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "string"
}
case object BinaryType extends DataType with PrimitiveType {
- type JvmType = Array[Byte]
+ private[sql] type JvmType = Array[Byte]
+ def simpleString: String = "binary"
}
case object BooleanType extends NativeType with PrimitiveType {
- type JvmType = Boolean
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Boolean
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "boolean"
}
case object TimestampType extends NativeType {
- type JvmType = Timestamp
+ private[sql] type JvmType = Timestamp
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val ordering = new Ordering[JvmType] {
+ private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
+
+ def simpleString: String = "timestamp"
}
abstract class NumericType extends NativeType with PrimitiveType {
@@ -150,7 +176,7 @@ abstract class NumericType extends NativeType with PrimitiveType {
// type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
// desugared by the compiler into an argument to the object's constructor. This means there is no
// longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
- val numeric: Numeric[JvmType]
+ private[sql] val numeric: Numeric[JvmType]
}
object NumericType {
@@ -166,39 +192,43 @@ object IntegralType {
}
abstract class IntegralType extends NumericType {
- val integral: Integral[JvmType]
+ private[sql] val integral: Integral[JvmType]
}
case object LongType extends IntegralType {
- type JvmType = Long
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Long]]
- val integral = implicitly[Integral[Long]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Long
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Long]]
+ private[sql] val integral = implicitly[Integral[Long]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "long"
}
case object IntegerType extends IntegralType {
- type JvmType = Int
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Int]]
- val integral = implicitly[Integral[Int]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Int
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Int]]
+ private[sql] val integral = implicitly[Integral[Int]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "integer"
}
case object ShortType extends IntegralType {
- type JvmType = Short
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Short]]
- val integral = implicitly[Integral[Short]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Short
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Short]]
+ private[sql] val integral = implicitly[Integral[Short]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "short"
}
case object ByteType extends IntegralType {
- type JvmType = Byte
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Byte]]
- val integral = implicitly[Integral[Byte]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Byte
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Byte]]
+ private[sql] val integral = implicitly[Integral[Byte]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "byte"
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
@@ -209,47 +239,159 @@ object FractionalType {
}
}
abstract class FractionalType extends NumericType {
- val fractional: Fractional[JvmType]
+ private[sql] val fractional: Fractional[JvmType]
}
case object DecimalType extends FractionalType {
- type JvmType = BigDecimal
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[BigDecimal]]
- val fractional = implicitly[Fractional[BigDecimal]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = BigDecimal
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[BigDecimal]]
+ private[sql] val fractional = implicitly[Fractional[BigDecimal]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "decimal"
}
case object DoubleType extends FractionalType {
- type JvmType = Double
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Double]]
- val fractional = implicitly[Fractional[Double]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Double
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Double]]
+ private[sql] val fractional = implicitly[Fractional[Double]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "double"
}
case object FloatType extends FractionalType {
- type JvmType = Float
- @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- val numeric = implicitly[Numeric[Float]]
- val fractional = implicitly[Fractional[Float]]
- val ordering = implicitly[Ordering[JvmType]]
+ private[sql] type JvmType = Float
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val numeric = implicitly[Numeric[Float]]
+ private[sql] val fractional = implicitly[Fractional[Float]]
+ private[sql] val ordering = implicitly[Ordering[JvmType]]
+ def simpleString: String = "float"
}
-case class ArrayType(elementType: DataType) extends DataType
+object ArrayType {
+ /** Construct an [[ArrayType]] object with the given element type. `containsNull` defaults to false. */
+ def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
+}
-case class StructField(name: String, dataType: DataType, nullable: Boolean)
+/**
+ * The data type for collections of multiple values.
+ * Internally these are represented as columns that contain a ``scala.collection.Seq``.
+ *
+ * @param elementType The data type of the array's elements.
+ * @param containsNull Indicates whether the array can contain `null` elements.
+ */
+case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(
+ s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
+ DataType.buildFormattedString(elementType, s"$prefix |", builder)
+ }
+
+ def simpleString: String = "array"
+}
+
+/**
+ * A field inside a StructType.
+ * @param name The name of this field.
+ * @param dataType The data type of this field.
+ * @param nullable Indicates if values of this field can be `null` values.
+ */
+case class StructField(name: String, dataType: DataType, nullable: Boolean) {
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
+ DataType.buildFormattedString(dataType, s"$prefix |", builder)
+ }
+}
object StructType {
- def fromAttributes(attributes: Seq[Attribute]): StructType = {
+ protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
- }
- // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq)
+ private def validateFields(fields: Seq[StructField]): Boolean =
+ fields.map(field => field.name).distinct.size == fields.size
}
case class StructType(fields: Seq[StructField]) extends DataType {
- def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+ require(StructType.validateFields(fields), "Found fields with the same name.")
+
+ /**
+ * Returns all field names in a [[Seq]].
+ */
+ lazy val fieldNames: Seq[String] = fields.map(_.name)
+ private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
+ private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+ /**
+ * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
+ * have a field with a matching name, an `IllegalArgumentException` will be thrown.
+ */
+ def apply(name: String): StructField = {
+ nameToField.get(name).getOrElse(
+ throw new IllegalArgumentException(s"Field ${name} does not exist."))
+ }
+
+ /**
+ * Returns a [[StructType]] containing [[StructField]]s of the given names.
+ * Names that do not have matching fields cause an `IllegalArgumentException` to be thrown.
+ */
+ def apply(names: Set[String]): StructType = {
+ val nonExistFields = names -- fieldNamesSet
+ if (!nonExistFields.isEmpty) {
+ throw new IllegalArgumentException(
+ s"Field ${nonExistFields.mkString(",")} does not exist.")
+ }
+ // Preserve the original order of fields.
+ StructType(fields.filter(f => names.contains(f.name)))
+ }
+
+ protected[sql] def toAttributes =
+ fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+
+ def treeString: String = {
+ val builder = new StringBuilder
+ builder.append("root\n")
+ val prefix = " |"
+ fields.foreach(field => field.buildFormattedString(prefix, builder))
+
+ builder.toString()
+ }
+
+ def printTreeString(): Unit = println(treeString)
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ fields.foreach(field => field.buildFormattedString(prefix, builder))
+ }
+
+ def simpleString: String = "struct"
+}
+
+object MapType {
+ /**
+ * Construct a [[MapType]] object with the given key type and value type.
+ * `valueContainsNull` defaults to true.
+ */
+ def apply(keyType: DataType, valueType: DataType): MapType =
+ MapType(keyType: DataType, valueType: DataType, true)
}
-case class MapType(keyType: DataType, valueType: DataType) extends DataType
+/**
+ * The data type for Maps. Keys in a map are not allowed to have `null` values.
+ * @param keyType The data type of map keys.
+ * @param valueType The data type of map values.
+ * @param valueContainsNull Indicates if map values have `null` values.
+ */
+case class MapType(
+ keyType: DataType,
+ valueType: DataType,
+ valueContainsNull: Boolean) extends DataType {
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
+ builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
+ s"(valueContainsNull = ${valueContainsNull})\n")
+ DataType.buildFormattedString(keyType, s"$prefix |", builder)
+ DataType.buildFormattedString(valueType, s"$prefix |", builder)
+ }
+
+ def simpleString: String = "map"
+}
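A minimal sketch of the reworked type API above, assuming spark-catalyst is on the classpath (e.g. a spark-shell session); the schema contents are illustrative and the printed tree is abbreviated.

    import org.apache.spark.sql.catalyst.types._

    // The new companion apply methods supply the nullability flags' defaults:
    ArrayType(StringType) == ArrayType(StringType, containsNull = false)        // true
    MapType(StringType, IntegerType) == MapType(StringType, IntegerType, true)  // true

    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("tags", ArrayType(StringType), nullable = true),
      StructField("scores", MapType(StringType, DoubleType), nullable = true)))

    schema("tags")               // StructField(tags,ArrayType(StringType,false),true)
    schema(Set("id", "scores"))  // StructType with just those fields, original order preserved
    // schema(Set("missing"))    // would throw IllegalArgumentException

    schema.printTreeString()
    // root
    //  |-- id: integer (nullable = false)
    //  |-- tags: array (nullable = true)
    //  ...   (nested element/key/value lines follow, via buildFormattedString)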
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index c0438dbe52..e030d6e13d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst
+import java.math.BigInteger
import java.sql.Timestamp
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
case class PrimitiveData(
@@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite {
StructField("_2", StringType, nullable = true))),
nullable = true))
}
+
+ test("get data type of a value") {
+ // BooleanType
+ assert(BooleanType === typeOfObject(true))
+ assert(BooleanType === typeOfObject(false))
+
+ // BinaryType
+ assert(BinaryType === typeOfObject("string".getBytes))
+
+ // StringType
+ assert(StringType === typeOfObject("string"))
+
+ // ByteType
+ assert(ByteType === typeOfObject(127.toByte))
+
+ // ShortType
+ assert(ShortType === typeOfObject(32767.toShort))
+
+ // IntegerType
+ assert(IntegerType === typeOfObject(2147483647))
+
+ // LongType
+ assert(LongType === typeOfObject(9223372036854775807L))
+
+ // FloatType
+ assert(FloatType === typeOfObject(3.4028235E38.toFloat))
+
+ // DoubleType
+ assert(DoubleType === typeOfObject(1.7976931348623157E308))
+
+ // DecimalType
+ assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))
+
+ // TimestampType
+ assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00")))
+
+ // NullType
+ assert(NullType === typeOfObject(null))
+
+ def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
+ case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigDecimal => DecimalType
+ case _ => StringType
+ }
+
+ assert(DecimalType === typeOfObject1(
+ new BigInteger("92233720368547758070")))
+ assert(DecimalType === typeOfObject1(
+ new java.math.BigDecimal("1.7976931348623157E318")))
+ assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
+
+ def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
+ case value: java.math.BigInteger => DecimalType
+ }
+
+ intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
+
+ def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse {
+ case c: Seq[_] => ArrayType(typeOfObject3(c.head))
+ }
+
+ assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3)))
+ assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3))))
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java
new file mode 100644
index 0000000000..17334ca31b
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing Lists.
+ * An ArrayType object comprises two fields, {@code DataType elementType} and
+ * {@code boolean containsNull}. The field {@code elementType} specifies the data type of
+ * array elements. The field {@code containsNull} specifies whether the array can contain
+ * {@code null} values.
+ *
+ * To create an {@link ArrayType},
+ * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType)} or
+ * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType, boolean)}
+ * should be used.
+ */
+public class ArrayType extends DataType {
+ private DataType elementType;
+ private boolean containsNull;
+
+ protected ArrayType(DataType elementType, boolean containsNull) {
+ this.elementType = elementType;
+ this.containsNull = containsNull;
+ }
+
+ public DataType getElementType() {
+ return elementType;
+ }
+
+ public boolean isContainsNull() {
+ return containsNull;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ ArrayType arrayType = (ArrayType) o;
+
+ if (containsNull != arrayType.containsNull) return false;
+ if (!elementType.equals(arrayType.elementType)) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = elementType.hashCode();
+ result = 31 * result + (containsNull ? 1 : 0);
+ return result;
+ }
+}
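A hedged sketch of obtaining the Java-side ArrayType; since the constructor is protected, the factory methods on DataType (added later in this patch) are the intended entry point. Shown from Scala for consistency with the other examples:

  import org.apache.spark.sql.api.java.types.DataType

  val withoutNulls = DataType.createArrayType(DataType.IntegerType)        // containsNull = false
  val withNulls    = DataType.createArrayType(DataType.IntegerType, true)  // containsNull = true
  withoutNulls == withNulls  // false: containsNull participates in equals/hashCode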
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java
index 4589129cd1..6170317985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java
@@ -15,22 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.sql.api.java.types;
/**
- * Allows the execution of relational queries, including those expressed in SQL using Spark.
+ * The data type representing byte[] values.
*
- * Note that this package is located in catalyst instead of in core so that all subprojects can
- * inherit the settings from this package object.
+ * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}.
*/
-package object sql {
-
- protected[sql] def Logger(name: String) =
- com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name))
-
- protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging
-
- type Row = catalyst.expressions.Row
-
- val Row = catalyst.expressions.Row
+public class BinaryType extends DataType {
+ protected BinaryType() {}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java
new file mode 100644
index 0000000000..8fa24d85d1
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing boolean and Boolean values.
+ *
+ * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}.
+ */
+public class BooleanType extends DataType {
+ protected BooleanType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java
new file mode 100644
index 0000000000..2de32978e2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing byte and Byte values.
+ *
+ * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}.
+ */
+public class ByteType extends DataType {
+ protected ByteType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java
new file mode 100644
index 0000000000..f84e5a490a
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * The base type of all Spark SQL data types.
+ *
+ * To get/create specific data type, users should use singleton objects and factory methods
+ * provided by this class.
+ */
+public abstract class DataType {
+
+ /**
+ * Gets the StringType object.
+ */
+ public static final StringType StringType = new StringType();
+
+ /**
+ * Gets the BinaryType object.
+ */
+ public static final BinaryType BinaryType = new BinaryType();
+
+ /**
+ * Gets the BooleanType object.
+ */
+ public static final BooleanType BooleanType = new BooleanType();
+
+ /**
+ * Gets the TimestampType object.
+ */
+ public static final TimestampType TimestampType = new TimestampType();
+
+ /**
+ * Gets the DecimalType object.
+ */
+ public static final DecimalType DecimalType = new DecimalType();
+
+ /**
+ * Gets the DoubleType object.
+ */
+ public static final DoubleType DoubleType = new DoubleType();
+
+ /**
+ * Gets the FloatType object.
+ */
+ public static final FloatType FloatType = new FloatType();
+
+ /**
+ * Gets the ByteType object.
+ */
+ public static final ByteType ByteType = new ByteType();
+
+ /**
+ * Gets the IntegerType object.
+ */
+ public static final IntegerType IntegerType = new IntegerType();
+
+ /**
+ * Gets the LongType object.
+ */
+ public static final LongType LongType = new LongType();
+
+ /**
+ * Gets the ShortType object.
+ */
+ public static final ShortType ShortType = new ShortType();
+
+ /**
+ * Creates an ArrayType by specifying the data type of elements ({@code elementType}).
+ * The field {@code containsNull} is set to {@code false}.
+ */
+ public static ArrayType createArrayType(DataType elementType) {
+ if (elementType == null) {
+ throw new IllegalArgumentException("elementType should not be null.");
+ }
+
+ return new ArrayType(elementType, false);
+ }
+
+ /**
+ * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and
+ * whether the array contains null values ({@code containsNull}).
+ */
+ public static ArrayType createArrayType(DataType elementType, boolean containsNull) {
+ if (elementType == null) {
+ throw new IllegalArgumentException("elementType should not be null.");
+ }
+
+ return new ArrayType(elementType, containsNull);
+ }
+
+ /**
+ * Creates a MapType by specifying the data type of keys ({@code keyType}) and values
+ * ({@code valueType}). The field {@code valueContainsNull} is set to {@code true}.
+ */
+ public static MapType createMapType(DataType keyType, DataType valueType) {
+ if (keyType == null) {
+ throw new IllegalArgumentException("keyType should not be null.");
+ }
+ if (valueType == null) {
+ throw new IllegalArgumentException("valueType should not be null.");
+ }
+
+ return new MapType(keyType, valueType, true);
+ }
+
+ /**
+ * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of
+ * values ({@code valueType}), and whether values can contain {@code null} values
+ * ({@code valueContainsNull}).
+ */
+ public static MapType createMapType(
+ DataType keyType,
+ DataType valueType,
+ boolean valueContainsNull) {
+ if (keyType == null) {
+ throw new IllegalArgumentException("keyType should not be null.");
+ }
+ if (valueType == null) {
+ throw new IllegalArgumentException("valueType should not be null.");
+ }
+
+ return new MapType(keyType, valueType, valueContainsNull);
+ }
+
+ /**
+ * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and
+ * whether values of this field can be null values ({@code nullable}).
+ */
+ public static StructField createStructField(String name, DataType dataType, boolean nullable) {
+ if (name == null) {
+ throw new IllegalArgumentException("name should not be null.");
+ }
+ if (dataType == null) {
+ throw new IllegalArgumentException("dataType should not be null.");
+ }
+
+ return new StructField(name, dataType, nullable);
+ }
+
+ /**
+ * Creates a StructType with the given list of StructFields ({@code fields}).
+ */
+ public static StructType createStructType(List<StructField> fields) {
+ return createStructType(fields.toArray(new StructField[0]));
+ }
+
+ /**
+ * Creates a StructType with the given StructField array ({@code fields}).
+ */
+ public static StructType createStructType(StructField[] fields) {
+ if (fields == null) {
+ throw new IllegalArgumentException("fields should not be null.");
+ }
+ Set<String> distinctNames = new HashSet<String>();
+ for (StructField field: fields) {
+ if (field == null) {
+ throw new IllegalArgumentException(
+ "fields should not contain any null.");
+ }
+
+ distinctNames.add(field.getName());
+ }
+ if (distinctNames.size() != fields.length) {
+ throw new IllegalArgumentException("fields should have distinct names.");
+ }
+
+ return new StructType(fields);
+ }
+
+}
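A small sketch, from Scala, of building a Java-side schema with the factory methods above; the field names are illustrative:

  import org.apache.spark.sql.api.java.types.{DataType, StructField, StructType}

  val fields: Array[StructField] = Array(
    DataType.createStructField("name", DataType.StringType, false),
    DataType.createStructField("scores", DataType.createArrayType(DataType.DoubleType), true))

  // createStructType rejects null fields and duplicate field names.
  val schema: StructType = DataType.createStructType(fields)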
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java
new file mode 100644
index 0000000000..9250491a2d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing java.math.BigDecimal values.
+ *
+ * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}.
+ */
+public class DecimalType extends DataType {
+ protected DecimalType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java
new file mode 100644
index 0000000000..3e86917fdd
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing double and Double values.
+ *
+ * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}.
+ */
+public class DoubleType extends DataType {
+ protected DoubleType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java
new file mode 100644
index 0000000000..fa860d4017
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing float and Float values.
+ *
+ * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}.
+ */
+public class FloatType extends DataType {
+ protected FloatType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java
new file mode 100644
index 0000000000..bd973eca2c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing int and Integer values.
+ *
+ * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}.
+ */
+public class IntegerType extends DataType {
+ protected IntegerType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java
new file mode 100644
index 0000000000..e00233304c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing long and Long values.
+ *
+ * {@code LongType} is represented by the singleton object {@link DataType#LongType}.
+ */
+public class LongType extends DataType {
+ protected LongType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java
new file mode 100644
index 0000000000..94936e2e4e
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing Maps. A MapType object comprises three fields,
+ * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}.
+ * The field {@code keyType} specifies the data type of keys in the map.
+ * The field {@code valueType} specifies the data type of values in the map.
+ * The field {@code valueContainsNull} specifies whether map values can contain
+ * {@code null} values.
+ * For values of a MapType column, keys are not allowed to have {@code null} values.
+ *
+ * To create a {@link MapType},
+ * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or
+ * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType, boolean)}
+ * should be used.
+ */
+public class MapType extends DataType {
+ private DataType keyType;
+ private DataType valueType;
+ private boolean valueContainsNull;
+
+ protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) {
+ this.keyType = keyType;
+ this.valueType = valueType;
+ this.valueContainsNull = valueContainsNull;
+ }
+
+ public DataType getKeyType() {
+ return keyType;
+ }
+
+ public DataType getValueType() {
+ return valueType;
+ }
+
+ public boolean isValueContainsNull() {
+ return valueContainsNull;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ MapType mapType = (MapType) o;
+
+ if (valueContainsNull != mapType.valueContainsNull) return false;
+ if (!keyType.equals(mapType.keyType)) return false;
+ if (!valueType.equals(mapType.valueType)) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = keyType.hashCode();
+ result = 31 * result + valueType.hashCode();
+ result = 31 * result + (valueContainsNull ? 1 : 0);
+ return result;
+ }
+}
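A two-line sketch of the MapType factories above, again from Scala:

  import org.apache.spark.sql.api.java.types.DataType

  val scores = DataType.createMapType(DataType.StringType, DataType.DoubleType)         // valueContainsNull = true
  val strict = DataType.createMapType(DataType.StringType, DataType.DoubleType, false)  // values must be non-null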
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java
new file mode 100644
index 0000000000..98f9507acf
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing short and Short values.
+ *
+ * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}.
+ */
+public class ShortType extends DataType {
+ protected ShortType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java
new file mode 100644
index 0000000000..b8e7dbe646
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing String values.
+ *
+ * {@code StringType} is represented by the singleton object {@link DataType#StringType}.
+ */
+public class StringType extends DataType {
+ protected StringType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java
new file mode 100644
index 0000000000..54e9c11ea4
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * A StructField object represents a field in a StructType object.
+ * A StructField object comprises three fields, {@code String name}, {@code DataType dataType},
+ * and {@code boolean nullable}. The field {@code name} is the name of the StructField.
+ * The field {@code dataType} specifies the data type of the StructField.
+ * The field {@code nullable} specifies whether values of the StructField can contain {@code null}
+ * values.
+ *
+ * To create a {@link StructField},
+ * {@link org.apache.spark.sql.api.java.types.DataType#createStructField(String, DataType, boolean)}
+ * should be used.
+ */
+public class StructField {
+ private String name;
+ private DataType dataType;
+ private boolean nullable;
+
+ protected StructField(String name, DataType dataType, boolean nullable) {
+ this.name = name;
+ this.dataType = dataType;
+ this.nullable = nullable;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public DataType getDataType() {
+ return dataType;
+ }
+
+ public boolean isNullable() {
+ return nullable;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ StructField that = (StructField) o;
+
+ if (nullable != that.nullable) return false;
+ if (!dataType.equals(that.dataType)) return false;
+ if (!name.equals(that.name)) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = name.hashCode();
+ result = 31 * result + dataType.hashCode();
+ result = 31 * result + (nullable ? 1 : 0);
+ return result;
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java
new file mode 100644
index 0000000000..33a42f4b16
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * The data type representing Rows.
+ * A StructType object comprises an array of StructFields.
+ *
+ * To create a {@link StructType},
+ * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(java.util.List)} or
+ * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(StructField[])}
+ * should be used.
+ */
+public class StructType extends DataType {
+ private StructField[] fields;
+
+ protected StructType(StructField[] fields) {
+ this.fields = fields;
+ }
+
+ public StructField[] getFields() {
+ return fields;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ StructType that = (StructType) o;
+
+ if (!Arrays.equals(fields, that.fields)) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(fields);
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java
new file mode 100644
index 0000000000..65295779f7
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java.types;
+
+/**
+ * The data type representing java.sql.Timestamp values.
+ *
+ * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}.
+ */
+public class TimestampType extends DataType {
+ protected TimestampType() {}
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java
new file mode 100644
index 0000000000..f169ac65e2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+/**
+ * Allows users to get and create Spark SQL data types.
+ */
+package org.apache.spark.sql.api.java.types;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index e4b6810180..86338752a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkStrategies
@@ -89,6 +88,44 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self))
/**
+ * :: DeveloperApi ::
+ * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val peopleSchemaRDD = sqlContext.applySchema(people, schema)
+ * peopleSchemaRDD.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * peopleSchemaRDD.registerAsTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group userf
+ */
+ @DeveloperApi
+ def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
+ // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
+ // schema differs from the existing schema on any field data type.
+ val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self)
+ new SchemaRDD(this, logicalPlan)
+ }
+
+ /**
* Loads a Parquet file, returning the result as a [[SchemaRDD]].
*
* @group userf
@@ -106,6 +143,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
+ * Loads a JSON file (one object per line) and applies the given schema,
+ * returning the result as a [[SchemaRDD]].
+ *
+ * @group userf
+ */
+ @Experimental
+ def jsonFile(path: String, schema: StructType): SchemaRDD = {
+ val json = sparkContext.textFile(path)
+ jsonRDD(json, schema)
+ }
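A hedged sketch of the new overload; the file path and field names are illustrative, and the imports are the same as in the applySchema example above:

  val schema =
    StructType(
      StructField("name", StringType, true) ::
      StructField("age", IntegerType, true) :: Nil)

  // Skips schema inference and applies the given schema to each parsed JSON record.
  val people = sqlContext.jsonFile("examples/src/main/resources/people.json", schema)
  people.registerAsTable("people")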
+
+ /**
+ * :: Experimental ::
*/
@Experimental
def jsonFile(path: String, samplingRatio: Double): SchemaRDD = {
@@ -124,10 +174,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
+ * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
+ * returning the result as a [[SchemaRDD]].
+ *
+ * @group userf
+ */
+ @Experimental
+ def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+ val appliedSchema =
+ Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ applySchema(rowRDD, appliedSchema)
+ }
+
+ /**
+ * :: Experimental ::
*/
@Experimental
- def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD =
- new SchemaRDD(this, JsonRDD.inferSchema(self, json, samplingRatio))
+ def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
+ val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ applySchema(rowRDD, appliedSchema)
+ }
/**
* :: Experimental ::
@@ -345,70 +413,138 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* Peek at the first row of the RDD and infer its schema.
- * TODO: consolidate this with the type system developed in SPARK-2060.
+ * It is only used by PySpark.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
import scala.collection.JavaConversions._
- def typeFor(obj: Any): DataType = obj match {
- case c: java.lang.String => StringType
- case c: java.lang.Integer => IntegerType
- case c: java.lang.Long => LongType
- case c: java.lang.Double => DoubleType
- case c: java.lang.Boolean => BooleanType
- case c: java.math.BigDecimal => DecimalType
- case c: java.sql.Timestamp => TimestampType
+
+ def typeOfComplexValue: PartialFunction[Any, DataType] = {
case c: java.util.Calendar => TimestampType
- case c: java.util.List[_] => ArrayType(typeFor(c.head))
+ case c: java.util.List[_] =>
+ ArrayType(typeOfObject(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
- MapType(typeFor(key), typeFor(value))
+ MapType(typeOfObject(key), typeOfObject(value))
case c if c.getClass.isArray =>
val elem = c.asInstanceOf[Array[_]].head
- ArrayType(typeFor(elem))
+ ArrayType(typeOfObject(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
+ def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue
+
val firstRow = rdd.first()
- val schema = firstRow.map { case (fieldName, obj) =>
- AttributeReference(fieldName, typeFor(obj), true)()
+ val fields = firstRow.map {
+ case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true)
}.toSeq
- def needTransform(obj: Any): Boolean = obj match {
- case c: java.util.List[_] => true
- case c: java.util.Map[_, _] => true
- case c if c.getClass.isArray => true
- case c: java.util.Calendar => true
- case c => false
+ applySchemaToPythonRDD(rdd, StructType(fields))
+ }
+
+ /**
+ * Parses the data type in our internal string representation. The data type string should
+ * have the same format as the one generated by `toString` in scala.
+ * It is only used by PySpark.
+ */
+ private[sql] def parseDataType(dataTypeString: String): DataType = {
+ val parser = org.apache.spark.sql.catalyst.types.DataType
+ parser(dataTypeString)
+ }
+
+ /**
+ * Applies the schema defined by `schemaString` to an RDD. It is only used by PySpark.
+ */
+ private[sql] def applySchemaToPythonRDD(
+ rdd: RDD[Map[String, _]],
+ schemaString: String): SchemaRDD = {
+ val schema = parseDataType(schemaString).asInstanceOf[StructType]
+ applySchemaToPythonRDD(rdd, schema)
+ }
+
+ /**
+ * Applies the given schema to an RDD. It is only used by PySpark.
+ */
+ private[sql] def applySchemaToPythonRDD(
+ rdd: RDD[Map[String, _]],
+ schema: StructType): SchemaRDD = {
+ // TODO: We should have a better implementation once we do not turn a Python side record
+ // to a Map.
+ import scala.collection.JavaConversions._
+ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
+
+ def needsConversion(dataType: DataType): Boolean = dataType match {
+ case ByteType => true
+ case ShortType => true
+ case FloatType => true
+ case TimestampType => true
+ case ArrayType(_, _) => true
+ case MapType(_, _, _) => true
+ case StructType(_) => true
+ case other => false
}
- // convert JList, JArray into Seq, convert JMap into Map
- // convert Calendar into Timestamp
- def transform(obj: Any): Any = obj match {
- case c: java.util.List[_] => c.map(transform).toSeq
- case c: java.util.Map[_, _] => c.map {
- case (key, value) => (key, transform(value))
- }.toMap
- case c if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(transform).toSeq
- case c: java.util.Calendar =>
- new java.sql.Timestamp(c.getTime().getTime())
- case c => c
+ // Converts value to the type specified by the data type.
+ // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+ // ByteType, we need to explicitly convert values in columns of these data types to the desired
+ // JVM data types.
+ def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ // TODO: We should check nullable
+ case (null, _) => null
+
+ case (c: java.util.List[_], ArrayType(elementType, _)) =>
+ val converted = c.map { e => convert(e, elementType)}
+ JListWrapper(converted)
+
+ case (c: java.util.Map[_, _], struct: StructType) =>
+ val row = new GenericMutableRow(struct.fields.length)
+ struct.fields.zipWithIndex.foreach {
+ case (field, i) =>
+ val value = convert(c.get(field.name), field.dataType)
+ row.update(i, value)
+ }
+ row
+
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+ val converted = c.map {
+ case (key, value) =>
+ (convert(key, keyType), convert(value, valueType))
+ }
+ JMapWrapper(converted)
+
+ case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+ val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType))
+ converted: Seq[Any]
+
+ case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime())
+ case (c: Int, ByteType) => c.toByte
+ case (c: Int, ShortType) => c.toShort
+ case (c: Double, FloatType) => c.toFloat
+
+ case (c, _) => c
+ }
+
+ val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
+ rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) })
+ } else {
+ rdd
}
- val need = firstRow.exists {case (key, value) => needTransform(value)}
- val transformed = if (need) {
- rdd.mapPartitions { iter =>
- iter.map {
- m => m.map {case (key, value) => (key, transform(value))}
+ val rowRdd = convertedRdd.mapPartitions { iter =>
+ val row = new GenericMutableRow(schema.fields.length)
+ val fieldsWithIndex = schema.fields.zipWithIndex
+ iter.map { m =>
+ // We cannot use m.values because the order of values returned by m.values may not
+ // match the order of fields.
+ fieldsWithIndex.foreach {
+ case (field, i) =>
+ val value =
+ m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull
+ row.update(i, value)
}
- }
- } else rdd
- val rowRdd = transformed.mapPartitions { iter =>
- iter.map { map =>
- new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+ row: Row
}
}
- new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(self))
- }
+ new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 172b6e0e7f..420f21fb9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.util.{Map => JMap, List => JList, Set => JSet}
+import java.util.{Map => JMap, List => JList}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.api.java.JavaRDD
@@ -120,6 +119,11 @@ class SchemaRDD(
override protected def getDependencies: Seq[Dependency[_]] =
List(new OneToOneDependency(queryExecution.toRdd))
+ /** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
+ *
+ * @group schema
+ */
+ def schema: StructType = queryExecution.analyzed.schema
// =======================================================================
// Query DSL
@@ -376,6 +380,8 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ import scala.collection.Map
+
def toJava(obj: Any, dataType: DataType): Any = dataType match {
case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
case array: ArrayType => obj match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index fd751031b2..6a20def475 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -123,9 +123,15 @@ private[sql] trait SchemaRDDLike {
def saveAsTable(tableName: String): Unit =
sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd
- /** Returns the output schema in the tree format. */
- def schemaString: String = queryExecution.analyzed.schemaString
+ /** Returns the schema as a string in the tree format.
+ *
+ * @group schema
+ */
+ def schemaString: String = baseSchemaRDD.schema.treeString
- /** Prints out the schema in the tree format. */
+ /** Prints out the schema.
+ *
+ * @group schema
+ */
def printSchema(): Unit = println(schemaString)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 85726bae54..c1c18a0cd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -21,14 +21,16 @@ import java.beans.Introspector
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.sql.api.java.types.{StructType => JStructType}
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
-import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.types.util.DataTypeConversions
+import DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils
/**
@@ -96,6 +98,21 @@ class JavaSQLContext(val sqlContext: SQLContext) {
}
/**
+ * :: DeveloperApi ::
+ * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD.
+ * It is important to make sure that the structure of every Row of the provided RDD matches the
+ * provided schema. Otherwise, a runtime exception will be thrown.
+ */
+ @DeveloperApi
+ def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = {
+ val scalaRowRDD = rowRDD.rdd.map(r => r.row)
+ val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType]
+ val logicalPlan =
+ SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ new JavaSchemaRDD(sqlContext, logicalPlan)
+ }
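A hedged sketch of this Java-side entry point, written in Scala for brevity; javaSparkContext and javaSqlContext are assumed to exist, and Row.create is the factory added to the Java Row class later in this patch:

  import org.apache.spark.sql.api.java.Row
  import org.apache.spark.sql.api.java.types.DataType

  val schema = DataType.createStructType(Array(
    DataType.createStructField("name", DataType.StringType, false),
    DataType.createStructField("age", DataType.IntegerType, true)))

  val rows = javaSparkContext.parallelize(java.util.Arrays.asList(
    Row.create("Alice", 29),
    Row.create("Bob", 31)))

  val people = javaSqlContext.applySchema(rows, schema)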
+
+ /**
* Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
*/
def parquetFile(path: String): JavaSchemaRDD =
@@ -104,23 +121,49 @@ class JavaSQLContext(val sqlContext: SQLContext) {
ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext))
/**
- * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
+ * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD.
* It goes through the entire dataset once to determine the schema.
- *
- * @group userf
*/
def jsonFile(path: String): JavaSchemaRDD =
jsonRDD(sqlContext.sparkContext.textFile(path))
/**
+ * :: Experimental ::
+ * Loads a JSON file (one object per line) and applies the given schema,
+ * returning the result as a JavaSchemaRDD.
+ */
+ @Experimental
+ def jsonFile(path: String, schema: JStructType): JavaSchemaRDD =
+ jsonRDD(sqlContext.sparkContext.textFile(path), schema)
+
+ /**
* Loads an RDD[String] storing JSON objects (one object per record), returning the result as a
- * [[JavaSchemaRDD]].
+ * JavaSchemaRDD.
* It goes through the entire dataset once to determine the schema.
- *
- * @group userf
*/
- def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD =
- new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(sqlContext, json, 1.0))
+ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
+ val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
+ val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ val logicalPlan =
+ SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ new JavaSchemaRDD(sqlContext, logicalPlan)
+ }
+
+ /**
+ * :: Experimental ::
+ * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
+ * returning the result as a JavaSchemaRDD.
+ */
+ @Experimental
+ def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = {
+ val appliedScalaSchema =
+ Option(asScalaDataType(schema)).getOrElse(
+ JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType]
+ val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ val logicalPlan =
+ SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ new JavaSchemaRDD(sqlContext, logicalPlan)
+ }
/**
* Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 8fbf13b8b0..8245741498 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -22,8 +22,11 @@ import java.util.{List => JList}
import org.apache.spark.Partitioner
import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.sql.api.java.types.StructType
+import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import DataTypeConversions._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -53,6 +56,10 @@ class JavaSchemaRDD(
override def toString: String = baseSchemaRDD.toString
+ /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */
+ def schema: StructType =
+ asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType]
+
// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 9b0dd21761..6c67934bda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.api.java
+import scala.annotation.varargs
+import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
+import scala.collection.JavaConversions
+import scala.math.BigDecimal
+
import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
/**
@@ -29,7 +34,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
/** Returns the value of column `i`. */
def get(i: Int): Any =
- row(i)
+ Row.toJavaValue(row(i))
/** Returns true if value at column `i` is NULL. */
def isNullAt(i: Int) = get(i) == null
@@ -89,5 +94,57 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
*/
def getString(i: Int): String =
row.getString(i)
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[Row]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: Row =>
+ (that canEqual this) &&
+ row == that.row
+ case _ => false
+ }
+
+ override def hashCode(): Int = row.hashCode()
}
+object Row {
+
+ private def toJavaValue(value: Any): Any = value match {
+ // For values of this ScalaRow, we will do the conversion when
+ // they are actually accessed.
+ case row: ScalaRow => new Row(row)
+ case map: scala.collection.Map[_, _] =>
+ JavaConversions.mapAsJavaMap(
+ map.map {
+ case (key, value) => (toJavaValue(key), toJavaValue(value))
+ }
+ )
+ case seq: scala.collection.Seq[_] =>
+ JavaConversions.seqAsJavaList(seq.map(toJavaValue))
+ case decimal: BigDecimal => decimal.underlying()
+ case other => other
+ }
+
+ // TODO: Consolidate the toScalaValue here with the scalafy in JsonRDD?
+ private def toScalaValue(value: Any): Any = value match {
+ // Values of this row have been converted to Scala values.
+ case row: Row => row.row
+ case map: java.util.Map[_, _] =>
+ JMapWrapper(map).map {
+ case (key, value) => (toScalaValue(key), toScalaValue(value))
+ }
+ case list: java.util.List[_] =>
+ JListWrapper(list).map(toScalaValue)
+ case decimal: java.math.BigDecimal => BigDecimal(decimal)
+ case other => other
+ }
+
+ /**
+ * Creates a Row with the given values.
+ */
+ @varargs def create(values: Any*): Row = {
+ // Right now, we cannot use @varargs to annotate the constructor of
+ // org.apache.spark.sql.api.java.Row. See https://issues.scala-lang.org/browse/SI-8383.
+ new Row(ScalaRow(values.map(toScalaValue):_*))
+ }
+}
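A brief sketch of the new factory and the value conversion described above; the literal values are illustrative:

  import org.apache.spark.sql.api.java.Row

  val row = Row.create("Alice", 29, new java.math.BigDecimal("3.14"))
  row.getString(0)  // "Alice"
  row.get(1)        // 29
  row.get(2)        // a java.math.BigDecimal again: stored as scala.BigDecimal, converted back on access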
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 6c2b553bb9..bd29ee421b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -25,33 +25,25 @@ import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
-import org.apache.spark.sql.{SQLContext, Logging}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.Logging
private[sql] object JsonRDD extends Logging {
+ private[sql] def jsonStringToRow(
+ json: RDD[String],
+ schema: StructType): RDD[Row] = {
+ parseJson(json).map(parsed => asRow(parsed, schema))
+ }
+
private[sql] def inferSchema(
- sqlContext: SQLContext,
json: RDD[String],
- samplingRatio: Double = 1.0): LogicalPlan = {
+ samplingRatio: Double = 1.0): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _)
- val baseSchema = createSchema(allKeys)
-
- createLogicalPlan(json, baseSchema, sqlContext)
- }
-
- private def createLogicalPlan(
- json: RDD[String],
- baseSchema: StructType,
- sqlContext: SQLContext): LogicalPlan = {
- val schema = nullTypeToStringType(baseSchema)
-
- SparkLogicalPlan(
- ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema))))(sqlContext)
+ createSchema(allKeys)
}
private def createSchema(allKeys: Set[(String, DataType)]): StructType = {
@@ -75,8 +67,8 @@ private[sql] object JsonRDD extends Logging {
val (topLevel, structLike) = values.partition(_.size == 1)
val topLevelFields = topLevel.filter {
name => resolved.get(prefix ++ name).get match {
- case ArrayType(StructType(Nil)) => false
- case ArrayType(_) => true
+ case ArrayType(StructType(Nil), _) => false
+ case ArrayType(_, _) => true
case struct: StructType => false
case _ => true
}
@@ -90,7 +82,8 @@ private[sql] object JsonRDD extends Logging {
val structType = makeStruct(nestedFields, prefix :+ name)
val dataType = resolved.get(prefix :+ name).get
dataType match {
- case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true))
+ case array: ArrayType =>
+ Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true))
case struct: StructType => Some(StructField(name, structType, nullable = true))
// dataType is StringType means that we have resolved type conflicts involving
// primitive types and complex types. So, the type of name has been relaxed to
@@ -109,6 +102,22 @@ private[sql] object JsonRDD extends Logging {
makeStruct(resolved.keySet.toSeq, Nil)
}
+ private[sql] def nullTypeToStringType(struct: StructType): StructType = {
+ val fields = struct.fields.map {
+ case StructField(fieldName, dataType, nullable) => {
+ val newType = dataType match {
+ case NullType => StringType
+ case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
+ case struct: StructType => nullTypeToStringType(struct)
+ case other: DataType => other
+ }
+ StructField(fieldName, newType, nullable)
+ }
+ }
+
+ StructType(fields)
+ }
+
/**
* Returns the most general data type for two given data types.
*/
@@ -139,8 +148,8 @@ private[sql] object JsonRDD extends Logging {
case StructField(name, _, _) => name
})
}
- case (ArrayType(elementType1), ArrayType(elementType2)) =>
- ArrayType(compatibleType(elementType1, elementType2))
+ case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+ ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
@@ -148,18 +157,13 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def typeOfPrimitiveValue(value: Any): DataType = {
- value match {
- case value: java.lang.String => StringType
- case value: java.lang.Integer => IntegerType
- case value: java.lang.Long => LongType
+ private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = {
+ ScalaReflection.typeOfObject orElse {
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
case value: java.math.BigInteger => DecimalType
- case value: java.lang.Double => DoubleType
+ // DecimalType's JvmType is Scala's BigDecimal; java.math.BigDecimal is handled here.
case value: java.math.BigDecimal => DecimalType
- case value: java.lang.Boolean => BooleanType
- case null => NullType
// Unexpected data type.
case _ => StringType
}
@@ -172,12 +176,13 @@ private[sql] object JsonRDD extends Logging {
* treat the element as String.
*/
private def typeOfArray(l: Seq[Any]): ArrayType = {
+ val containsNull = l.exists(v => v == null)
val elements = l.flatMap(v => Option(v))
if (elements.isEmpty) {
// If this JSON array is empty, we use NullType as a placeholder.
// If this array is not empty in other JSON objects, we can resolve
// the type after we have passed through all JSON objects.
- ArrayType(NullType)
+ ArrayType(NullType, containsNull)
} else {
val elementType = elements.map {
e => e match {
@@ -189,7 +194,7 @@ private[sql] object JsonRDD extends Logging {
}
}.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2))
- ArrayType(elementType)
+ ArrayType(elementType, containsNull)
}
}
@@ -216,15 +221,16 @@ private[sql] object JsonRDD extends Logging {
case (key: String, array: Seq[_]) => {
// The value associated with the key is an array.
typeOfArray(array) match {
- case ArrayType(StructType(Nil)) => {
+ case ArrayType(StructType(Nil), containsNull) => {
// The elements of this array are structs.
array.asInstanceOf[Seq[Map[String, Any]]].flatMap {
element => allKeysWithValueTypes(element)
}.map {
case (k, dataType) => (s"$key.$k", dataType)
- } :+ (key, ArrayType(StructType(Nil)))
+ } :+ (key, ArrayType(StructType(Nil), containsNull))
}
- case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil
+ case ArrayType(elementType, containsNull) =>
+ (key, ArrayType(elementType, containsNull)) :: Nil
}
}
case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil
@@ -262,8 +268,11 @@ private[sql] object JsonRDD extends Logging {
// the ObjectMapper will take the last value associated with this duplicate key.
// For example: for {"key": 1, "key":2}, we will get "key"->2.
val mapper = new ObjectMapper()
- iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]]))
- }).map(scalafy).map(_.asInstanceOf[Map[String, Any]])
+ iter.map { record =>
+ val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
+ parsed.asInstanceOf[Map[String, Any]]
+ }
+ })
}
private def toLong(value: Any): Long = {
@@ -334,7 +343,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case ArrayType(elementType) =>
+ case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case StringType => toString(value)
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
@@ -348,6 +357,7 @@ private[sql] object JsonRDD extends Logging {
}
private def asRow(json: Map[String,Any], schema: StructType): Row = {
+ // TODO: Reuse the row instead of creating a new one for every record.
val row = new GenericMutableRow(schema.fields.length)
schema.fields.zipWithIndex.foreach {
// StructType
@@ -356,7 +366,7 @@ private[sql] object JsonRDD extends Logging {
v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull)
// ArrayType(StructType)
- case (StructField(name, ArrayType(structType: StructType), _), i) =>
+ case (StructField(name, ArrayType(structType: StructType, _), _), i) =>
row.update(i,
json.get(name).flatMap(v => Option(v)).map(
v => v.asInstanceOf[Seq[Any]].map(
@@ -370,32 +380,4 @@ private[sql] object JsonRDD extends Logging {
row
}
-
- private def nullTypeToStringType(struct: StructType): StructType = {
- val fields = struct.fields.map {
- case StructField(fieldName, dataType, nullable) => {
- val newType = dataType match {
- case NullType => StringType
- case ArrayType(NullType) => ArrayType(StringType)
- case struct: StructType => nullTypeToStringType(struct)
- case other: DataType => other
- }
- StructField(fieldName, newType, nullable)
- }
- }
-
- StructType(fields)
- }
-
- private def asAttributes(struct: StructType): Seq[AttributeReference] = {
- struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
- }
-
- private def asStruct(attributes: Seq[AttributeReference]): StructType = {
- val fields = attributes.map {
- case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable)
- }
-
- StructType(fields)
- }
}
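With plan construction removed from JsonRDD, schema inference and row construction are now separate steps driven by the caller. A rough sketch of that flow, assuming a caller inside the sql package with access to these private[sql] helpers (the method name jsonToSchemaRDD is hypothetical; applySchema is the SQLContext method exercised by SQLQuerySuite below):

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{SQLContext, SchemaRDD}

  def jsonToSchemaRDD(sqlContext: SQLContext, json: RDD[String]): SchemaRDD = {
    // Infer a StructType, relax NullType placeholders to StringType, then build rows.
    val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio = 1.0))
    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
    sqlContext.applySchema(rowRDD, appliedSchema)
  }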
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/package-info.java
index 5360361451..5360361451 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package-info.java
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
new file mode 100644
index 0000000000..0995a4eb62
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Allows the execution of relational queries, including those expressed in SQL using Spark.
+ *
+ * @groupname dataType Data types
+ * @groupdesc dataType Spark SQL data types.
+ * @groupprio dataType -3
+ * @groupname field Field
+ * @groupprio field -2
+ * @groupname row Row
+ * @groupprio row -1
+ */
+package object sql {
+
+ protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Represents one row of output from a relational operator.
+ * @group row
+ */
+ @DeveloperApi
+ type Row = catalyst.expressions.Row
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * A [[Row]] object can be constructed by providing field values. Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * // Create a Row from values.
+ * Row(value1, value2, value3, ...)
+ * // Create a Row from a Seq of values.
+ * Row.fromSeq(Seq(value1, value2, ...))
+ * }}}
+ *
+ * A value of a row can be accessed either through generic access by ordinal,
+ * which will incur boxing overhead for primitives, or through native primitive access.
+ * An example of generic access by ordinal:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val row = Row(1, true, "a string", null)
+ * // row: Row = [1,true,a string,null]
+ * val firstValue = row(0)
+ * // firstValue: Any = 1
+ * val fourthValue = row(3)
+ * // fourthValue: Any = null
+ * }}}
+ *
+ * For native primitive access, it is invalid to use the native primitive interface to retrieve
+ * a value that is null; instead, a user must check `isNullAt` before attempting to retrieve a
+ * value that might be null.
+ * An example of native primitive access:
+ * {{{
+ * // using the row from the previous example.
+ * val firstValue = row.getInt(0)
+ * // firstValue: Int = 1
+ * val isNull = row.isNullAt(3)
+ * // isNull: Boolean = true
+ * }}}
+ *
+ * Interfaces related to native primitive access are:
+ *
+ * `isNullAt(i: Int): Boolean`
+ *
+ * `getInt(i: Int): Int`
+ *
+ * `getLong(i: Int): Long`
+ *
+ * `getDouble(i: Int): Double`
+ *
+ * `getFloat(i: Int): Float`
+ *
+ * `getBoolean(i: Int): Boolean`
+ *
+ * `getShort(i: Int): Short`
+ *
+ * `getByte(i: Int): Byte`
+ *
+ * `getString(i: Int): String`
+ *
+ * Fields in a [[Row]] object can be extracted in a pattern match. Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val pairs = sql("SELECT key, value FROM src").rdd.map {
+ * case Row(key: Int, value: String) =>
+ * key -> value
+ * }
+ * }}}
+ *
+ * @group row
+ */
+ @DeveloperApi
+ val Row = catalyst.expressions.Row
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The base type of all Spark SQL data types.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type DataType = catalyst.types.DataType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `String` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val StringType = catalyst.types.StringType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Array[Byte]` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val BinaryType = catalyst.types.BinaryType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Boolean` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val BooleanType = catalyst.types.BooleanType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `java.sql.Timestamp` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val TimestampType = catalyst.types.TimestampType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `scala.math.BigDecimal` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val DecimalType = catalyst.types.DecimalType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Double` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val DoubleType = catalyst.types.DoubleType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Float` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val FloatType = catalyst.types.FloatType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Byte` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val ByteType = catalyst.types.ByteType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Int` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val IntegerType = catalyst.types.IntegerType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Long` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val LongType = catalyst.types.LongType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Short` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val ShortType = catalyst.types.ShortType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type for collections of multiple values.
+ * Internally these are represented as columns that contain a `scala.collection.Seq`.
+ *
+ * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
+ * `containsNull: Boolean`. `elementType` specifies the type of the array's elements and
+ * `containsNull` specifies whether the array can contain `null` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type ArrayType = catalyst.types.ArrayType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * An [[ArrayType]] object can be constructed in two ways:
+ * {{{
+ * ArrayType(elementType: DataType, containsNull: Boolean)
+ * }}} and
+ * {{{
+ * ArrayType(elementType: DataType)
+ * }}}
+ * For `ArrayType(elementType)`, `containsNull` defaults to `false`.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val ArrayType = catalyst.types.ArrayType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Map`s. A [[MapType]] object comprises three fields,
+ * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`.
+ * `keyType` specifies the type of the map's keys, `valueType` specifies the type of its values,
+ * and `valueContainsNull` specifies whether a value of this map can be `null`.
+ * For values of a MapType column, keys are not allowed to be `null`.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type MapType = catalyst.types.MapType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * A [[MapType]] object can be constructed in two ways:
+ * {{{
+ * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
+ * }}} and
+ * {{{
+ * MapType(keyType: DataType, valueType: DataType)
+ * }}}
+ * For `MapType(keyType: DataType, valueType: DataType)`, `valueContainsNull` defaults to `true`.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val MapType = catalyst.types.MapType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing [[Row]]s.
+ * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type StructType = catalyst.types.StructType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * A [[StructType]] object can be constructed by
+ * {{{
+ * StructType(fields: Seq[StructField])
+ * }}}
+ * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by name.
+ * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned.
+ * If a provided name does not have a matching field, an `IllegalArgumentException` will be
+ * thrown.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val struct =
+ * StructType(
+ * StructField("a", IntegerType, true) ::
+ * StructField("b", LongType, false) ::
+ * StructField("c", BooleanType, false) :: Nil)
+ *
+ * // Extract a single StructField.
+ * val singleField = struct("b")
+ * // singleField: StructField = StructField(b,LongType,false)
+ *
+ * // This struct does not have a field called "d"; extracting it throws an
+ * // IllegalArgumentException.
+ * // struct("d")
+ *
+ * // Extract multiple StructFields. Field names are provided in a set.
+ * // A StructType object will be returned.
+ * val twoFields = struct(Set("b", "c"))
+ * // twoFields: StructType =
+ * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false)))
+ *
+ * // If any of the provided names does not have a matching field, an
+ * // IllegalArgumentException is thrown. For the case shown below, "d" does
+ * // not exist in struct, so the extraction fails.
+ * // struct(Set("b", "c", "d"))
+ * }}}
+ *
+ * A [[Row]] object is used as the value of a [[StructType]].
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val innerStruct =
+ * StructType(
+ * StructField("f1", IntegerType, true) ::
+ * StructField("f2", LongType, false) ::
+ * StructField("f3", BooleanType, false) :: Nil)
+ *
+ * val struct = StructType(
+ * StructField("a", innerStruct, true) :: Nil)
+ *
+ * // Create a Row with the schema defined by struct
+ * val row = Row(Row(1, 2, true))
+ * // row: Row = [[1,2,true]]
+ * }}}
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val StructType = catalyst.types.StructType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * A [[StructField]] object represents a field in a [[StructType]] object.
+ * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`,
+ * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of
+ * `dataType` specifies the data type of a `StructField`.
+ * The field of `nullable` specifies if values of a `StructField` can contain `null` values.
+ *
+ * @group field
+ */
+ @DeveloperApi
+ type StructField = catalyst.types.StructField
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * A [[StructField]] object can be constructed by
+ * {{{
+ * StructField(name: String, dataType: DataType, nullable: Boolean)
+ * }}}
+ *
+ * @group field
+ */
+ @DeveloperApi
+ val StructField = catalyst.types.StructField
+}
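Since ArrayType and MapType now carry containsNull / valueContainsNull, pattern matches on these types have to bind or ignore the extra field, as the JSON and Parquet changes in this patch do. A small, self-contained sketch of the updated matching style using the aliases exported by this package object (the describe helper is illustrative only):

  import org.apache.spark.sql._

  def describe(dataType: DataType): String = dataType match {
    case ArrayType(elementType, containsNull) =>
      s"array<${describe(elementType)}> (containsNull = $containsNull)"
    case MapType(keyType, valueType, valueContainsNull) =>
      s"map<${describe(keyType)}, ${describe(valueType)}> (valueContainsNull = $valueContainsNull)"
    case StructType(fields) =>
      fields.map(f => s"${f.name}: ${describe(f.dataType)}").mkString("struct<", ", ", ">")
    case other => other.toString
  }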
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index de8fe2dae3..0a3b59cbc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -75,21 +75,21 @@ private[sql] object CatalystConverter {
val fieldType: DataType = field.dataType
fieldType match {
// For native JVM types we use a converter with native arrays
- case ArrayType(elementType: NativeType) => {
+ case ArrayType(elementType: NativeType, false) => {
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
}
// This is for other types of arrays, including those with nested fields
- case ArrayType(elementType: DataType) => {
+ case ArrayType(elementType: DataType, false) => {
new CatalystArrayConverter(elementType, fieldIndex, parent)
}
case StructType(fields: Seq[StructField]) => {
new CatalystStructConverter(fields.toArray, fieldIndex, parent)
}
- case MapType(keyType: DataType, valueType: DataType) => {
+ case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => {
new CatalystMapConverter(
Array(
new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false),
- new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)),
+ new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)),
fieldIndex,
parent)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 39294a3f4b..6d4ce32ac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -172,10 +172,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
- case t @ ArrayType(_) => writeArray(
+ case t @ ArrayType(_, false) => writeArray(
t,
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
- case t @ MapType(_, _) => writeMap(
+ case t @ MapType(_, _, _) => writeMap(
t,
value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
case t @ StructType(_) => writeStruct(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 58370b955a..aaef1a1d47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetOriginalType.LIST => { // TODO: check enums!
assert(groupType.getFieldCount == 1)
val field = groupType.getFields.apply(0)
- new ArrayType(toDataType(field))
+ ArrayType(toDataType(field), containsNull = false)
}
case ParquetOriginalType.MAP => {
assert(
@@ -130,7 +130,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
val valueType = toDataType(keyValueGroup.getFields.apply(1))
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
- new MapType(keyType, valueType)
+ // TODO: set valueContainsNull explicitly instead of assuming it is true here.
+ MapType(keyType, valueType)
}
case _ => {
// Note: the order of these checks is important!
@@ -140,10 +142,12 @@ private[parquet] object ParquetTypesConverter extends Logging {
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
val valueType = toDataType(keyValueGroup.getFields.apply(1))
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
- new MapType(keyType, valueType)
+ // TODO: set valueContainsNull explicitly instead of assuming it is true here.
+ MapType(keyType, valueType)
} else if (correspondsToArray(groupType)) { // ArrayType
val elementType = toDataType(groupType.getFields.apply(0))
- new ArrayType(elementType)
+ ArrayType(elementType, containsNull = false)
} else { // everything else: StructType
val fields = groupType
.getFields
@@ -151,7 +155,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
ptype.getName,
toDataType(ptype),
ptype.getRepetition != Repetition.REQUIRED))
- new StructType(fields)
+ StructType(fields)
}
}
}
@@ -234,7 +238,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
}.getOrElse {
ctype match {
- case ArrayType(elementType) => {
+ case ArrayType(elementType, false) => {
val parquetElementType = fromDataType(
elementType,
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
@@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
new ParquetGroupType(repetition, name, fields)
}
- case MapType(keyType, valueType) => {
+ case MapType(keyType, valueType, _) => {
val parquetKeyType =
fromDataType(
keyType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
new file mode 100644
index 0000000000..d1aa3c8d53
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types.util
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField}
+
+import scala.collection.JavaConverters._
+
+protected[sql] object DataTypeConversions {
+
+ /**
+ * Returns the equivalent StructField in Java for the given StructField in Scala.
+ */
+ def asJavaStructField(scalaStructField: StructField): JStructField = {
+ JDataType.createStructField(
+ scalaStructField.name,
+ asJavaDataType(scalaStructField.dataType),
+ scalaStructField.nullable)
+ }
+
+ /**
+ * Returns the equivalent DataType in Java for the given DataType in Scala.
+ */
+ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
+ case StringType => JDataType.StringType
+ case BinaryType => JDataType.BinaryType
+ case BooleanType => JDataType.BooleanType
+ case TimestampType => JDataType.TimestampType
+ case DecimalType => JDataType.DecimalType
+ case DoubleType => JDataType.DoubleType
+ case FloatType => JDataType.FloatType
+ case ByteType => JDataType.ByteType
+ case IntegerType => JDataType.IntegerType
+ case LongType => JDataType.LongType
+ case ShortType => JDataType.ShortType
+
+ case arrayType: ArrayType => JDataType.createArrayType(
+ asJavaDataType(arrayType.elementType), arrayType.containsNull)
+ case mapType: MapType => JDataType.createMapType(
+ asJavaDataType(mapType.keyType),
+ asJavaDataType(mapType.valueType),
+ mapType.valueContainsNull)
+ case structType: StructType => JDataType.createStructType(
+ structType.fields.map(asJavaStructField).asJava)
+ }
+
+ /**
+ * Returns the equivalent StructField in Scala for the given StructField in Java.
+ */
+ def asScalaStructField(javaStructField: JStructField): StructField = {
+ StructField(
+ javaStructField.getName,
+ asScalaDataType(javaStructField.getDataType),
+ javaStructField.isNullable)
+ }
+
+ /**
+ * Returns the equivalent DataType in Scala for the given DataType in Java.
+ */
+ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
+ case stringType: org.apache.spark.sql.api.java.types.StringType =>
+ StringType
+ case binaryType: org.apache.spark.sql.api.java.types.BinaryType =>
+ BinaryType
+ case booleanType: org.apache.spark.sql.api.java.types.BooleanType =>
+ BooleanType
+ case timestampType: org.apache.spark.sql.api.java.types.TimestampType =>
+ TimestampType
+ case decimalType: org.apache.spark.sql.api.java.types.DecimalType =>
+ DecimalType
+ case doubleType: org.apache.spark.sql.api.java.types.DoubleType =>
+ DoubleType
+ case floatType: org.apache.spark.sql.api.java.types.FloatType =>
+ FloatType
+ case byteType: org.apache.spark.sql.api.java.types.ByteType =>
+ ByteType
+ case integerType: org.apache.spark.sql.api.java.types.IntegerType =>
+ IntegerType
+ case longType: org.apache.spark.sql.api.java.types.LongType =>
+ LongType
+ case shortType: org.apache.spark.sql.api.java.types.ShortType =>
+ ShortType
+
+ case arrayType: org.apache.spark.sql.api.java.types.ArrayType =>
+ ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull)
+ case mapType: org.apache.spark.sql.api.java.types.MapType =>
+ MapType(
+ asScalaDataType(mapType.getKeyType),
+ asScalaDataType(mapType.getValueType),
+ mapType.isValueContainsNull)
+ case structType: org.apache.spark.sql.api.java.types.StructType =>
+ StructType(structType.getFields.map(asScalaStructField))
+ }
+}
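A minimal round-trip sketch for these converters, mirroring the ScalaSideDataTypeConversionSuite added later in this patch (the schema itself is made up; DataTypeConversions is protected[sql], so this only works from code inside the org.apache.spark.sql package):

  import org.apache.spark.sql._
  import org.apache.spark.sql.types.util.DataTypeConversions._

  val scalaSchema = StructType(
    StructField("name", StringType, nullable = false) ::
    StructField("scores", ArrayType(DoubleType, containsNull = true), nullable = true) :: Nil)

  // Convert to the Java-side representation and back; the result should equal the original.
  val javaSchema = asJavaDataType(scalaSchema)
  val roundTripped = asScalaDataType(javaSchema)
  assert(roundTripped == scalaSchema)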
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
new file mode 100644
index 0000000000..8ee4591105
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.io.Serializable;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.sql.api.java.types.DataType;
+import org.apache.spark.sql.api.java.types.StructField;
+import org.apache.spark.sql.api.java.types.StructType;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaApplySchemaSuite implements Serializable {
+ private transient JavaSparkContext javaCtx;
+ private transient JavaSQLContext javaSqlCtx;
+
+ @Before
+ public void setUp() {
+ javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite");
+ javaSqlCtx = new JavaSQLContext(javaCtx);
+ }
+
+ @After
+ public void tearDown() {
+ javaCtx.stop();
+ javaCtx = null;
+ javaSqlCtx = null;
+ }
+
+ public static class Person implements Serializable {
+ private String name;
+ private int age;
+
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public int getAge() {
+ return age;
+ }
+
+ public void setAge(int age) {
+ this.age = age;
+ }
+ }
+
+ @Test
+ public void applySchema() {
+ List<Person> personList = new ArrayList<Person>(2);
+ Person person1 = new Person();
+ person1.setName("Michael");
+ person1.setAge(29);
+ personList.add(person1);
+ Person person2 = new Person();
+ person2.setName("Yin");
+ person2.setAge(28);
+ personList.add(person2);
+
+ JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+ new Function<Person, Row>() {
+ public Row call(Person person) throws Exception {
+ return Row.create(person.getName(), person.getAge());
+ }
+ });
+
+ List<StructField> fields = new ArrayList<StructField>(2);
+ fields.add(DataType.createStructField("name", DataType.StringType, false));
+ fields.add(DataType.createStructField("age", DataType.IntegerType, false));
+ StructType schema = DataType.createStructType(fields);
+
+ JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema);
+ schemaRDD.registerAsTable("people");
+ List<Row> actual = javaSqlCtx.sql("SELECT * FROM people").collect();
+
+ List<Row> expected = new ArrayList<Row>(2);
+ expected.add(Row.create("Michael", 29));
+ expected.add(Row.create("Yin", 28));
+
+ Assert.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void applySchemaToJSON() {
+ JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList(
+ "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
+ "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
+ "\"boolean\":true, \"null\":null}",
+ "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " +
+ "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
+ "\"boolean\":false, \"null\":null}"));
+ List<StructField> fields = new ArrayList<StructField>(7);
+ fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true));
+ fields.add(DataType.createStructField("boolean", DataType.BooleanType, true));
+ fields.add(DataType.createStructField("double", DataType.DoubleType, true));
+ fields.add(DataType.createStructField("integer", DataType.IntegerType, true));
+ fields.add(DataType.createStructField("long", DataType.LongType, true));
+ fields.add(DataType.createStructField("null", DataType.StringType, true));
+ fields.add(DataType.createStructField("string", DataType.StringType, true));
+ StructType expectedSchema = DataType.createStructType(fields);
+ List<Row> expectedResult = new ArrayList<Row>(2);
+ expectedResult.add(
+ Row.create(
+ new BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string."));
+ expectedResult.add(
+ Row.create(
+ new BigDecimal("92233720368547758069"),
+ false,
+ 1.7976931348623157E305,
+ 11,
+ 21474836469L,
+ null,
+ "this is another simple string."));
+
+ JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD);
+ StructType actualSchema1 = schemaRDD1.schema();
+ Assert.assertEquals(expectedSchema, actualSchema1);
+ schemaRDD1.registerAsTable("jsonTable1");
+ List<Row> actual1 = javaSqlCtx.sql("select * from jsonTable1").collect();
+ Assert.assertEquals(expectedResult, actual1);
+
+ JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema);
+ StructType actualSchema2 = schemaRDD2.schema();
+ Assert.assertEquals(expectedSchema, actualSchema2);
+ schemaRDD2.registerAsTable("jsonTable2");
+ List<Row> actual2 = javaSqlCtx.sql("select * from jsonTable2").collect();
+ Assert.assertEquals(expectedResult, actual2);
+ }
+}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
new file mode 100644
index 0000000000..52d07b5425
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.math.BigDecimal;
+import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaRowSuite {
+ private byte byteValue;
+ private short shortValue;
+ private int intValue;
+ private long longValue;
+ private float floatValue;
+ private double doubleValue;
+ private BigDecimal decimalValue;
+ private boolean booleanValue;
+ private String stringValue;
+ private byte[] binaryValue;
+ private Timestamp timestampValue;
+
+ @Before
+ public void setUp() {
+ byteValue = (byte)127;
+ shortValue = (short)32767;
+ intValue = 2147483647;
+ longValue = 9223372036854775807L;
+ floatValue = (float)3.4028235E38;
+ doubleValue = 1.7976931348623157E308;
+ decimalValue = new BigDecimal("1.7976931348623157E328");
+ booleanValue = true;
+ stringValue = "this is a string";
+ binaryValue = stringValue.getBytes();
+ timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0");
+ }
+
+ @Test
+ public void constructSimpleRow() {
+ Row simpleRow = Row.create(
+ byteValue, // ByteType
+ new Byte(byteValue),
+ shortValue, // ShortType
+ new Short(shortValue),
+ intValue, // IntegerType
+ new Integer(intValue),
+ longValue, // LongType
+ new Long(longValue),
+ floatValue, // FloatType
+ new Float(floatValue),
+ doubleValue, // DoubleType
+ new Double(doubleValue),
+ decimalValue, // DecimalType
+ booleanValue, // BooleanType
+ new Boolean(booleanValue),
+ stringValue, // StringType
+ binaryValue, // BinaryType
+ timestampValue, // TimestampType
+ null // null
+ );
+
+ Assert.assertEquals(byteValue, simpleRow.getByte(0));
+ Assert.assertEquals(byteValue, simpleRow.get(0));
+ Assert.assertEquals(byteValue, simpleRow.getByte(1));
+ Assert.assertEquals(byteValue, simpleRow.get(1));
+ Assert.assertEquals(shortValue, simpleRow.getShort(2));
+ Assert.assertEquals(shortValue, simpleRow.get(2));
+ Assert.assertEquals(shortValue, simpleRow.getShort(3));
+ Assert.assertEquals(shortValue, simpleRow.get(3));
+ Assert.assertEquals(intValue, simpleRow.getInt(4));
+ Assert.assertEquals(intValue, simpleRow.get(4));
+ Assert.assertEquals(intValue, simpleRow.getInt(5));
+ Assert.assertEquals(intValue, simpleRow.get(5));
+ Assert.assertEquals(longValue, simpleRow.getLong(6));
+ Assert.assertEquals(longValue, simpleRow.get(6));
+ Assert.assertEquals(longValue, simpleRow.getLong(7));
+ Assert.assertEquals(longValue, simpleRow.get(7));
+ // When we create the row, we do not do any conversion
+ // for a float/double value, so we just set the delta to 0.
+ Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0);
+ Assert.assertEquals(floatValue, simpleRow.get(8));
+ Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0);
+ Assert.assertEquals(floatValue, simpleRow.get(9));
+ Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0);
+ Assert.assertEquals(doubleValue, simpleRow.get(10));
+ Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0);
+ Assert.assertEquals(doubleValue, simpleRow.get(11));
+ Assert.assertEquals(decimalValue, simpleRow.get(12));
+ Assert.assertEquals(booleanValue, simpleRow.getBoolean(13));
+ Assert.assertEquals(booleanValue, simpleRow.get(13));
+ Assert.assertEquals(booleanValue, simpleRow.getBoolean(14));
+ Assert.assertEquals(booleanValue, simpleRow.get(14));
+ Assert.assertEquals(stringValue, simpleRow.getString(15));
+ Assert.assertEquals(stringValue, simpleRow.get(15));
+ Assert.assertEquals(binaryValue, simpleRow.get(16));
+ Assert.assertEquals(timestampValue, simpleRow.get(17));
+ Assert.assertEquals(true, simpleRow.isNullAt(18));
+ Assert.assertEquals(null, simpleRow.get(18));
+ }
+
+ @Test
+ public void constructComplexRow() {
+ // Simple array
+ List<String> simpleStringArray = Arrays.asList(
+ stringValue + " (1)", stringValue + " (2)", stringValue + " (3)");
+
+ // Simple map
+ Map<String, Long> simpleMap = new HashMap<String, Long>();
+ simpleMap.put(stringValue + " (1)", longValue);
+ simpleMap.put(stringValue + " (2)", longValue - 1);
+ simpleMap.put(stringValue + " (3)", longValue - 2);
+
+ // Simple struct
+ Row simpleStruct = Row.create(
+ doubleValue, stringValue, timestampValue, null);
+
+ // Complex array
+ List<Map<String, Long>> arrayOfMaps = Arrays.asList(simpleMap);
+ List<Row> arrayOfRows = Arrays.asList(simpleStruct);
+
+ // Complex map
+ Map<List<Row>, Row> complexMap = new HashMap<List<Row>, Row>();
+ complexMap.put(arrayOfRows, simpleStruct);
+
+ // Complex struct
+ Row complexStruct = Row.create(
+ simpleStringArray,
+ simpleMap,
+ simpleStruct,
+ arrayOfMaps,
+ arrayOfRows,
+ complexMap,
+ null);
+ Assert.assertEquals(simpleStringArray, complexStruct.get(0));
+ Assert.assertEquals(simpleMap, complexStruct.get(1));
+ Assert.assertEquals(simpleStruct, complexStruct.get(2));
+ Assert.assertEquals(arrayOfMaps, complexStruct.get(3));
+ Assert.assertEquals(arrayOfRows, complexStruct.get(4));
+ Assert.assertEquals(complexMap, complexStruct.get(5));
+ Assert.assertEquals(null, complexStruct.get(6));
+
+ // A very complex row
+ Row complexRow = Row.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct);
+ Assert.assertEquals(arrayOfMaps, complexRow.get(0));
+ Assert.assertEquals(arrayOfRows, complexRow.get(1));
+ Assert.assertEquals(complexMap, complexRow.get(2));
+ Assert.assertEquals(complexStruct, complexRow.get(3));
+ }
+}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
new file mode 100644
index 0000000000..96a503962f
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.util.List;
+import java.util.ArrayList;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.sql.types.util.DataTypeConversions;
+import org.apache.spark.sql.api.java.types.DataType;
+import org.apache.spark.sql.api.java.types.StructField;
+
+public class JavaSideDataTypeConversionSuite {
+ public void checkDataType(DataType javaDataType) {
+ org.apache.spark.sql.catalyst.types.DataType scalaDataType =
+ DataTypeConversions.asScalaDataType(javaDataType);
+ DataType actual = DataTypeConversions.asJavaDataType(scalaDataType);
+ Assert.assertEquals(javaDataType, actual);
+ }
+
+ @Test
+ public void createDataTypes() {
+ // Simple DataTypes.
+ checkDataType(DataType.StringType);
+ checkDataType(DataType.BinaryType);
+ checkDataType(DataType.BooleanType);
+ checkDataType(DataType.TimestampType);
+ checkDataType(DataType.DecimalType);
+ checkDataType(DataType.DoubleType);
+ checkDataType(DataType.FloatType);
+ checkDataType(DataType.ByteType);
+ checkDataType(DataType.IntegerType);
+ checkDataType(DataType.LongType);
+ checkDataType(DataType.ShortType);
+
+ // Simple ArrayType.
+ DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true);
+ checkDataType(simpleJavaArrayType);
+
+ // Simple MapType.
+ DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType);
+ checkDataType(simpleJavaMapType);
+
+ // Simple StructType.
+ List<StructField> simpleFields = new ArrayList<StructField>();
+ simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
+ simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
+ simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false));
+ DataType simpleJavaStructType = DataType.createStructType(simpleFields);
+ checkDataType(simpleJavaStructType);
+
+ // Complex StructType.
+ List<StructField> complexFields = new ArrayList<StructField>();
+ complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true));
+ complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true));
+ complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true));
+ complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false));
+ DataType complexJavaStructType = DataType.createStructType(complexFields);
+ checkDataType(complexJavaStructType);
+
+ // Complex ArrayType.
+ DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true);
+ checkDataType(complexJavaArrayType);
+
+ // Complex MapType.
+ DataType complexJavaMapType =
+ DataType.createMapType(complexJavaStructType, complexJavaArrayType, false);
+ checkDataType(complexJavaMapType);
+ }
+
+ @Test
+ public void illegalArgument() {
+ // ArrayType
+ try {
+ DataType.createArrayType(null, true);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+
+ // MapType
+ try {
+ DataType.createMapType(null, DataType.StringType);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ try {
+ DataType.createMapType(DataType.StringType, null);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ try {
+ DataType.createMapType(null, null);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+
+ // StructField
+ try {
+ DataType.createStructField(null, DataType.StringType, true);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ try {
+ DataType.createStructField("name", null, true);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ try {
+ DataType.createStructField(null, null, true);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+
+ // StructType
+ try {
+ List<StructField> simpleFields = new ArrayList<StructField>();
+ simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
+ simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
+ simpleFields.add(null);
+ DataType.createStructType(simpleFields);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ try {
+ List<StructField> simpleFields = new ArrayList<StructField>();
+ simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true));
+ simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
+ DataType.createStructType(simpleFields);
+ Assert.fail();
+ } catch (IllegalArgumentException expectedException) {
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
new file mode 100644
index 0000000000..cf7d79f42d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -0,0 +1,58 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+
+class DataTypeSuite extends FunSuite {
+
+ test("construct an ArrayType") {
+ val array = ArrayType(StringType)
+
+ assert(ArrayType(StringType, false) === array)
+ }
+
+ test("construct an MapType") {
+ val map = MapType(StringType, IntegerType)
+
+ assert(MapType(StringType, IntegerType, true) === map)
+ }
+
+ test("extract fields from a StructType") {
+ val struct = StructType(
+ StructField("a", IntegerType, true) ::
+ StructField("b", LongType, false) ::
+ StructField("c", StringType, true) ::
+ StructField("d", FloatType, true) :: Nil)
+
+ assert(StructField("b", LongType, false) === struct("b"))
+
+ intercept[IllegalArgumentException] {
+ struct("e")
+ }
+
+ val expectedStruct = StructType(
+ StructField("b", LongType, false) ::
+ StructField("d", FloatType, true) :: Nil)
+
+ assert(expectedStruct === struct(Set("b", "d")))
+ intercept[IllegalArgumentException] {
+ struct(Set("b", "d", "e", "f"))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
new file mode 100644
index 0000000000..651cb735ab
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -0,0 +1,46 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
+class RowSuite extends FunSuite {
+
+ test("create row") {
+ val expected = new GenericMutableRow(4)
+ expected.update(0, 2147483647)
+ expected.update(1, "this is a string")
+ expected.update(2, false)
+ expected.update(3, null)
+ val actual1 = Row(2147483647, "this is a string", false, null)
+ assert(expected.size === actual1.size)
+ assert(expected.getInt(0) === actual1.getInt(0))
+ assert(expected.getString(1) === actual1.getString(1))
+ assert(expected.getBoolean(2) === actual1.getBoolean(2))
+ assert(expected(3) === actual1(3))
+
+ val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
+ assert(expected.size === actual2.size)
+ assert(expected.getInt(0) === actual2.getInt(0))
+ assert(expected.getString(1) === actual2.getString(1))
+ assert(expected.getBoolean(2) === actual2.getBoolean(2))
+ assert(expected(3) === actual2(3))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index de9e8aa4f6..bebb490645 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.test._
/* Implicits */
@@ -446,4 +444,66 @@ class SQLQuerySuite extends QueryTest {
)
clear()
}
+
+ test("apply schema") {
+ val schema1 = StructType(
+ StructField("f1", IntegerType, false) ::
+ StructField("f2", StringType, false) ::
+ StructField("f3", BooleanType, false) ::
+ StructField("f4", IntegerType, true) :: Nil)
+
+ val rowRDD1 = unparsedStrings.map { r =>
+ val values = r.split(",").map(_.trim)
+ val v4 = try values(3).toInt catch {
+ case _: NumberFormatException => null
+ }
+ Row(values(0).toInt, values(1), values(2).toBoolean, v4)
+ }
+
+ val schemaRDD1 = applySchema(rowRDD1, schema1)
+ schemaRDD1.registerAsTable("applySchema1")
+ checkAnswer(
+ sql("SELECT * FROM applySchema1"),
+ (1, "A1", true, null) ::
+ (2, "B2", false, null) ::
+ (3, "C3", true, null) ::
+ (4, "D4", true, 2147483644) :: Nil)
+
+ checkAnswer(
+ sql("SELECT f1, f4 FROM applySchema1"),
+ (1, null) ::
+ (2, null) ::
+ (3, null) ::
+ (4, 2147483644) :: Nil)
+
+ val schema2 = StructType(
+ StructField("f1", StructType(
+ StructField("f11", IntegerType, false) ::
+ StructField("f12", BooleanType, false) :: Nil), false) ::
+ StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil)
+
+ val rowRDD2 = unparsedStrings.map { r =>
+ val values = r.split(",").map(_.trim)
+ val v4 = try values(3).toInt catch {
+ case _: NumberFormatException => null
+ }
+ Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
+ }
+
+ val schemaRDD2 = applySchema(rowRDD2, schema2)
+ schemaRDD2.registerAsTable("applySchema2")
+ checkAnswer(
+ sql("SELECT * FROM applySchema2"),
+ (Seq(1, true), Map("A1" -> null)) ::
+ (Seq(2, false), Map("B2" -> null)) ::
+ (Seq(3, true), Map("C3" -> null)) ::
+ (Seq(4, true), Map("D4" -> 2147483644)) :: Nil)
+
+ checkAnswer(
+ sql("SELECT f1.f11, f2['D4'] FROM applySchema2"),
+ (1, null) ::
+ (2, null) ::
+ (3, null) ::
+ (4, 2147483644) :: Nil)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 330b20b315..213190e812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -128,4 +128,11 @@ object TestData {
case class TableName(tableName: String)
TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName")
+
+ val unparsedStrings =
+ TestSQLContext.sparkContext.parallelize(
+ "1, A1, true, null" ::
+ "2, B2, false, null" ::
+ "3, C3, true, null" ::
+ "4, D4, true, 2147483644" :: Nil)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
new file mode 100644
index 0000000000..46de6fe239
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java
+
+import org.apache.spark.sql.types.util.DataTypeConversions
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql._
+import DataTypeConversions._
+
+class ScalaSideDataTypeConversionSuite extends FunSuite {
+
+ def checkDataType(scalaDataType: DataType) {
+ val javaDataType = asJavaDataType(scalaDataType)
+ val actual = asScalaDataType(javaDataType)
+ assert(scalaDataType === actual, s"Converted data type ${actual} " +
+ s"does not equal the expected data type ${scalaDataType}")
+ }
+
+ test("convert data types") {
+ // Simple DataTypes.
+ checkDataType(StringType)
+ checkDataType(BinaryType)
+ checkDataType(BooleanType)
+ checkDataType(TimestampType)
+ checkDataType(DecimalType)
+ checkDataType(DoubleType)
+ checkDataType(FloatType)
+ checkDataType(ByteType)
+ checkDataType(IntegerType)
+ checkDataType(LongType)
+ checkDataType(ShortType)
+
+ // Simple ArrayType.
+ val simpleScalaArrayType = ArrayType(StringType, true)
+ checkDataType(simpleScalaArrayType)
+
+ // Simple MapType.
+ val simpleScalaMapType = MapType(StringType, LongType)
+ checkDataType(simpleScalaMapType)
+
+ // Simple StructType.
+ val simpleScalaStructType = StructType(
+ StructField("a", DecimalType, false) ::
+ StructField("b", BooleanType, true) ::
+ StructField("c", LongType, true) ::
+ StructField("d", BinaryType, false) :: Nil)
+ checkDataType(simpleScalaStructType)
+
+ // Complex StructType.
+ val complexScalaStructType = StructType(
+ StructField("simpleArray", simpleScalaArrayType, true) ::
+ StructField("simpleMap", simpleScalaMapType, true) ::
+ StructField("simpleStruct", simpleScalaStructType, true) ::
+ StructField("boolean", BooleanType, false) :: Nil)
+ checkDataType(complexScalaStructType)
+
+ // Complex ArrayType.
+ val complexScalaArrayType = ArrayType(complexScalaStructType, true)
+ checkDataType(complexScalaArrayType)
+
+ // Complex MapType.
+ val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false)
+ checkDataType(complexScalaMapType)
+ }
+}
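Sketch (not part of the patch): the Scala-to-Java round trip the checkDataType helper above verifies, spelled out for one nested type. The package names come from the imports above; the nested type itself is illustrative.

    import org.apache.spark.sql._
    import org.apache.spark.sql.types.util.DataTypeConversions.{asJavaDataType, asScalaDataType}

    // array<struct<name: string, scores: map<string, int>>> with nullable elements.
    val scalaType = ArrayType(
      StructType(
        StructField("name", StringType, false) ::
        StructField("scores", MapType(StringType, IntegerType, true), true) :: Nil),
      true)

    // Converting to the Java-side representation and back should be lossless.
    val javaType  = asJavaDataType(scalaType)
    val roundTrip = asScalaDataType(javaType)
    assert(roundTrip == scalaType)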
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index e765cfc83a..9d9cfdd7c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -17,16 +17,12 @@
package org.apache.spark.sql.json
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.TestSQLContext._
-protected case class Schema(output: Seq[Attribute]) extends LeafNode
-
class JsonSuite extends QueryTest {
import TestJsonData._
TestJsonData
@@ -127,6 +123,18 @@ class JsonSuite extends QueryTest {
checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType))
checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType))
checkDataType(ArrayType(IntegerType), StructType(Nil), StringType)
+ checkDataType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true))
+ checkDataType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true))
+ checkDataType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
+ checkDataType(
+ ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false))
+ checkDataType(
+ ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false))
+ checkDataType(
+ ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType))
// StructType
checkDataType(StructType(Nil), StructType(Nil), StructType(Nil))
@@ -164,16 +172,16 @@ class JsonSuite extends QueryTest {
test("Primitive field and type inferring") {
val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
- val expectedSchema =
- AttributeReference("bigInteger", DecimalType, true)() ::
- AttributeReference("boolean", BooleanType, true)() ::
- AttributeReference("double", DoubleType, true)() ::
- AttributeReference("integer", IntegerType, true)() ::
- AttributeReference("long", LongType, true)() ::
- AttributeReference("null", StringType, true)() ::
- AttributeReference("string", StringType, true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("bigInteger", DecimalType, true) ::
+ StructField("boolean", BooleanType, true) ::
+ StructField("double", DoubleType, true) ::
+ StructField("integer", IntegerType, true) ::
+ StructField("long", LongType, true) ::
+ StructField("null", StringType, true) ::
+ StructField("string", StringType, true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -192,27 +200,28 @@ class JsonSuite extends QueryTest {
test("Complex field and type inferring") {
val jsonSchemaRDD = jsonRDD(complexFieldAndType)
- val expectedSchema =
- AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() ::
- AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() ::
- AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() ::
- AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() ::
- AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() ::
- AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() ::
- AttributeReference("arrayOfLong", ArrayType(LongType), true)() ::
- AttributeReference("arrayOfNull", ArrayType(StringType), true)() ::
- AttributeReference("arrayOfString", ArrayType(StringType), true)() ::
- AttributeReference("arrayOfStruct", ArrayType(
- StructType(StructField("field1", BooleanType, true) ::
- StructField("field2", StringType, true) :: Nil)), true)() ::
- AttributeReference("struct", StructType(
- StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType, true) :: Nil), true)() ::
- AttributeReference("structWithArrayFields", StructType(
+ val expectedSchema = StructType(
+ StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) ::
+ StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType), true) ::
+ StructField("arrayOfBoolean", ArrayType(BooleanType), true) ::
+ StructField("arrayOfDouble", ArrayType(DoubleType), true) ::
+ StructField("arrayOfInteger", ArrayType(IntegerType), true) ::
+ StructField("arrayOfLong", ArrayType(LongType), true) ::
+ StructField("arrayOfNull", ArrayType(StringType, true), true) ::
+ StructField("arrayOfString", ArrayType(StringType), true) ::
+ StructField("arrayOfStruct", ArrayType(
+ StructType(
+ StructField("field1", BooleanType, true) ::
+ StructField("field2", StringType, true) :: Nil)), true) ::
+ StructField("struct", StructType(
+ StructField("field1", BooleanType, true) ::
+ StructField("field2", DecimalType, true) :: Nil), true) ::
+ StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType), true) ::
- StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil
+ StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -301,15 +310,15 @@ class JsonSuite extends QueryTest {
test("Type conflict in primitive field values") {
val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict)
- val expectedSchema =
- AttributeReference("num_bool", StringType, true)() ::
- AttributeReference("num_num_1", LongType, true)() ::
- AttributeReference("num_num_2", DecimalType, true)() ::
- AttributeReference("num_num_3", DoubleType, true)() ::
- AttributeReference("num_str", StringType, true)() ::
- AttributeReference("str_bool", StringType, true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("num_bool", StringType, true) ::
+ StructField("num_num_1", LongType, true) ::
+ StructField("num_num_2", DecimalType, true) ::
+ StructField("num_num_3", DoubleType, true) ::
+ StructField("num_str", StringType, true) ::
+ StructField("str_bool", StringType, true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -426,15 +435,15 @@ class JsonSuite extends QueryTest {
test("Type conflict in complex field values") {
val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict)
- val expectedSchema =
- AttributeReference("array", ArrayType(IntegerType), true)() ::
- AttributeReference("num_struct", StringType, true)() ::
- AttributeReference("str_array", StringType, true)() ::
- AttributeReference("struct", StructType(
- StructField("field", StringType, true) :: Nil), true)() ::
- AttributeReference("struct_array", StringType, true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("array", ArrayType(IntegerType), true) ::
+ StructField("num_struct", StringType, true) ::
+ StructField("str_array", StringType, true) ::
+ StructField("struct", StructType(
+ StructField("field", StringType, true) :: Nil), true) ::
+ StructField("struct_array", StringType, true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -450,12 +459,12 @@ class JsonSuite extends QueryTest {
test("Type conflict in array elements") {
val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict)
- val expectedSchema =
- AttributeReference("array1", ArrayType(StringType), true)() ::
- AttributeReference("array2", ArrayType(StructType(
- StructField("field", LongType, true) :: Nil)), true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("array1", ArrayType(StringType, true), true) ::
+ StructField("array2", ArrayType(StructType(
+ StructField("field", LongType, true) :: Nil)), true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -475,15 +484,15 @@ class JsonSuite extends QueryTest {
test("Handling missing fields") {
val jsonSchemaRDD = jsonRDD(missingFields)
- val expectedSchema =
- AttributeReference("a", BooleanType, true)() ::
- AttributeReference("b", LongType, true)() ::
- AttributeReference("c", ArrayType(IntegerType), true)() ::
- AttributeReference("d", StructType(
- StructField("field", BooleanType, true) :: Nil), true)() ::
- AttributeReference("e", StringType, true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("a", BooleanType, true) ::
+ StructField("b", LongType, true) ::
+ StructField("c", ArrayType(IntegerType), true) ::
+ StructField("d", StructType(
+ StructField("field", BooleanType, true) :: Nil), true) ::
+ StructField("e", StringType, true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
}
@@ -494,16 +503,16 @@ class JsonSuite extends QueryTest {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val jsonSchemaRDD = jsonFile(path)
- val expectedSchema =
- AttributeReference("bigInteger", DecimalType, true)() ::
- AttributeReference("boolean", BooleanType, true)() ::
- AttributeReference("double", DoubleType, true)() ::
- AttributeReference("integer", IntegerType, true)() ::
- AttributeReference("long", LongType, true)() ::
- AttributeReference("null", StringType, true)() ::
- AttributeReference("string", StringType, true)() :: Nil
+ val expectedSchema = StructType(
+ StructField("bigInteger", DecimalType, true) ::
+ StructField("boolean", BooleanType, true) ::
+ StructField("double", DoubleType, true) ::
+ StructField("integer", IntegerType, true) ::
+ StructField("long", LongType, true) ::
+ StructField("null", StringType, true) ::
+ StructField("string", StringType, true) :: Nil)
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output))
+ assert(expectedSchema === jsonSchemaRDD.schema)
jsonSchemaRDD.registerAsTable("jsonTable")
@@ -518,4 +527,53 @@ class JsonSuite extends QueryTest {
"this is a simple string.") :: Nil
)
}
+
+ test("Applying schemas") {
+ val file = getTempFilePath("json")
+ val path = file.toString
+ primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
+
+ val schema = StructType(
+ StructField("bigInteger", DecimalType, true) ::
+ StructField("boolean", BooleanType, true) ::
+ StructField("double", DoubleType, true) ::
+ StructField("integer", IntegerType, true) ::
+ StructField("long", LongType, true) ::
+ StructField("null", StringType, true) ::
+ StructField("string", StringType, true) :: Nil)
+
+ val jsonSchemaRDD1 = jsonFile(path, schema)
+
+ assert(schema === jsonSchemaRDD1.schema)
+
+ jsonSchemaRDD1.registerAsTable("jsonTable1")
+
+ checkAnswer(
+ sql("select * from jsonTable1"),
+ (BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.") :: Nil
+ )
+
+ val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema)
+
+ assert(schema === jsonSchemaRDD2.schema)
+
+ jsonSchemaRDD2.registerAsTable("jsonTable2")
+
+ checkAnswer(
+ sql("select * from jsonTable2"),
+ (BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.") :: Nil
+ )
+ }
}
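Sketch (not part of the patch): the user-supplied-schema overloads the new "Applying schemas" test covers. The schema and input here are illustrative, and sc/sqlContext are the hypothetical handles from the first sketch; jsonFile(path, schema) and jsonRDD(rdd, schema) are the overloads exercised above.

    import org.apache.spark.sql._

    // With an explicit schema, type inference over the JSON input is skipped entirely.
    val jsonSchema = StructType(
      StructField("name", StringType, true) ::
      StructField("age", IntegerType, true) :: Nil)

    val fromFile = sqlContext.jsonFile("/tmp/people.json", jsonSchema)   // hypothetical path
    val fromRDD  = sqlContext.jsonRDD(
      sc.parallelize("""{"name":"A","age":1}""" :: Nil), jsonSchema)

    assert(fromFile.schema == jsonSchema && fromRDD.schema == jsonSchema)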
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index f0a61270da..b413373345 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.{Command => PhysicalCommand}
import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand
@@ -260,9 +259,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
- case (seq: Seq[_], ArrayType(typ)) =>
+ case (seq: Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
- case (map: Map[_,_], MapType(kType, vType)) =>
+ case (map: Map[_,_], MapType(kType, vType, _)) =>
map.map {
case (key, value) =>
toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
@@ -279,9 +278,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
- case (seq: Seq[_], ArrayType(typ)) =>
+ case (seq: Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
- case (map: Map[_,_], MapType(kType, vType)) =>
+ case (map: Map[_,_], MapType(kType, vType, _)) =>
map.map {
case (key, value) =>
toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
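Note (not part of the patch): the change above is mechanical. ArrayType now carries containsNull and MapType carries valueContainsNull, so existing matches simply bind and discard the extra field with _. A small sketch of the new extractor shapes; typeName is an illustrative helper, not code from this commit.

    import org.apache.spark.sql._

    def typeName(dt: DataType): String = dt match {
      // The nullability flag is irrelevant for rendering, so it is matched with _ as in HiveContext above.
      case ArrayType(elementType, _)      => s"array of ${typeName(elementType)}"
      case MapType(keyType, valueType, _) => s"map from ${typeName(keyType)} to ${typeName(valueType)}"
      case StructType(fields)             => fields.map(f => s"${f.name}: ${typeName(f.dataType)}").mkString("struct<", ", ", ">")
      case other                          => other.toString
    }

    typeName(MapType(StringType, ArrayType(IntegerType)))
    // "map from StringType to array of IntegerType"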
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index ad7dc0ecdb..354fcd53f3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -152,8 +152,9 @@ private[hive] trait HiveInspectors {
}
def toInspector(dataType: DataType): ObjectInspector = dataType match {
- case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
- case MapType(keyType, valueType) =>
+ case ArrayType(tpe, _) =>
+ ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case MapType(keyType, valueType, _) =>
ObjectInspectorFactory.getStandardMapObjectInspector(
toInspector(keyType), toInspector(valueType))
case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
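Note (not part of the patch): as above, only the element, key and value types feed into inspector construction; the new nullability flags are dropped. A sketch of the container cases, assuming hive-exec on the classpath; buildInspector is an illustrative stand-in for toInspector.

    import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
    import org.apache.spark.sql._

    def buildInspector(dataType: DataType): ObjectInspector = dataType match {
      case ArrayType(elementType, _) =>
        ObjectInspectorFactory.getStandardListObjectInspector(buildInspector(elementType))
      case MapType(keyType, valueType, _) =>
        ObjectInspectorFactory.getStandardMapObjectInspector(
          buildInspector(keyType), buildInspector(valueType))
      case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
      // ...other primitive cases elided, as in HiveInspectors above.
    }

    // array<map<string,string>> becomes a list inspector over a map inspector.
    buildInspector(ArrayType(MapType(StringType, StringType)))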
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index dff1d6a4b9..fa4e78439c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -200,7 +200,9 @@ object HiveMetastoreTypes extends RegexParsers {
"varchar\\((\\d+)\\)".r ^^^ StringType
protected lazy val arrayType: Parser[DataType] =
- "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType
+ "array" ~> "<" ~> dataType <~ ">" ^^ {
+ case tpe => ArrayType(tpe)
+ }
protected lazy val mapType: Parser[DataType] =
"map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
@@ -229,10 +231,10 @@ object HiveMetastoreTypes extends RegexParsers {
}
def toMetastoreType(dt: DataType): String = dt match {
- case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>"
+ case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
case StructType(fields) =>
s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
- case MapType(keyType, valueType) =>
+ case MapType(keyType, valueType, _) =>
s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
case StringType => "string"
case FloatType => "float"