 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala       |  6
 sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala    | 85
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                  | 15
 sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala         |  7
 sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala |  2
 5 files changed, 90 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 8d8b5b86d5..54006e20a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -417,6 +417,12 @@ object StructType extends AbstractDataType {
}
}
+ /**
+ * Creates StructType for a given DDL-formatted string, which is a comma-separated list of field
+ * definitions, e.g., a INT, b STRING.
+ */
+ def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl)
+
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
def apply(fields: java.util.List[StructField]): StructType = {
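
A minimal sketch of how the new helper is used, assuming a build where this change is present
(field names below are illustrative):

    import org.apache.spark.sql.types._

    // Parse a comma-separated list of field definitions into a StructType.
    val schema: StructType = StructType.fromDDL("a INT, b STRING, c ARRAY<DOUBLE>")
    assert(schema.fieldNames.toSeq == Seq("a", "b", "c"))
    assert(schema("c").dataType == ArrayType(DoubleType))
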
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 61e1ec7c7a..05cb999af6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite {
assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
}
- def checkDataTypeJsonRepr(dataType: DataType): Unit = {
- test(s"JSON - $dataType") {
+ def checkDataTypeFromJson(dataType: DataType): Unit = {
+ test(s"from Json - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
}
}
- checkDataTypeJsonRepr(NullType)
- checkDataTypeJsonRepr(BooleanType)
- checkDataTypeJsonRepr(ByteType)
- checkDataTypeJsonRepr(ShortType)
- checkDataTypeJsonRepr(IntegerType)
- checkDataTypeJsonRepr(LongType)
- checkDataTypeJsonRepr(FloatType)
- checkDataTypeJsonRepr(DoubleType)
- checkDataTypeJsonRepr(DecimalType(10, 5))
- checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
- checkDataTypeJsonRepr(DateType)
- checkDataTypeJsonRepr(TimestampType)
- checkDataTypeJsonRepr(StringType)
- checkDataTypeJsonRepr(BinaryType)
- checkDataTypeJsonRepr(ArrayType(DoubleType, true))
- checkDataTypeJsonRepr(ArrayType(StringType, false))
- checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
- checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ def checkDataTypeFromDDL(dataType: DataType): Unit = {
+ test(s"from DDL - $dataType") {
+ val parsed = StructType.fromDDL(s"a ${dataType.sql}")
+ val expected = new StructType().add("a", dataType)
+ assert(parsed.sameType(expected))
+ }
+ }
+
+ checkDataTypeFromJson(NullType)
+
+ checkDataTypeFromJson(BooleanType)
+ checkDataTypeFromDDL(BooleanType)
+
+ checkDataTypeFromJson(ByteType)
+ checkDataTypeFromDDL(ByteType)
+
+ checkDataTypeFromJson(ShortType)
+ checkDataTypeFromDDL(ShortType)
+
+ checkDataTypeFromJson(IntegerType)
+ checkDataTypeFromDDL(IntegerType)
+
+ checkDataTypeFromJson(LongType)
+ checkDataTypeFromDDL(LongType)
+
+ checkDataTypeFromJson(FloatType)
+ checkDataTypeFromDDL(FloatType)
+
+ checkDataTypeFromJson(DoubleType)
+ checkDataTypeFromDDL(DoubleType)
+
+ checkDataTypeFromJson(DecimalType(10, 5))
+ checkDataTypeFromDDL(DecimalType(10, 5))
+
+ checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT)
+ checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT)
+
+ checkDataTypeFromJson(DateType)
+ checkDataTypeFromDDL(DateType)
+
+ checkDataTypeFromJson(TimestampType)
+ checkDataTypeFromDDL(TimestampType)
+
+ checkDataTypeFromJson(StringType)
+ checkDataTypeFromDDL(StringType)
+
+ checkDataTypeFromJson(BinaryType)
+ checkDataTypeFromDDL(BinaryType)
+
+ checkDataTypeFromJson(ArrayType(DoubleType, true))
+ checkDataTypeFromDDL(ArrayType(DoubleType, true))
+
+ checkDataTypeFromJson(ArrayType(StringType, false))
+ checkDataTypeFromDDL(ArrayType(StringType, false))
+
+ checkDataTypeFromJson(MapType(IntegerType, StringType, true))
+ checkDataTypeFromDDL(MapType(IntegerType, StringType, true))
+
+ checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
val metadata = new MetadataBuilder()
.putString("name", "age")
@@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite {
StructField("a", IntegerType, nullable = true),
StructField("b", ArrayType(DoubleType), nullable = false),
StructField("c", DoubleType, nullable = false, metadata)))
- checkDataTypeJsonRepr(structType)
+ checkDataTypeFromJson(structType)
+ checkDataTypeFromDDL(structType)
def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
test(s"Check the default size of $dataType") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acdb8e2d3e..0f9203065e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.util.Try
+import scala.util.control.NonFatal
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.ScalaReflection
@@ -3055,13 +3056,21 @@ object functions {
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string as a json string
+ * @param schema the schema to use when parsing the json string. In Spark 2.1,
+ * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL
+ * format is also supported for the schema.
*
* @group collection_funcs
* @since 2.1.0
*/
- def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
- from_json(e, DataType.fromJson(schema), options)
+ def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
+ val dataType = try {
+ DataType.fromJson(schema)
+ } catch {
+ case NonFatal(_) => StructType.fromDDL(schema)
+ }
+ from_json(e, dataType, options)
+ }
/**
* (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
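
With this change the string-schema overload first tries to read the schema as a JSON-serialized
DataType and, on failure, parses it as DDL, so both forms keep working. A hedged usage sketch,
assuming a SparkSession named `spark` with `spark.implicits._` in scope:

    import org.apache.spark.sql.functions.from_json
    import scala.collection.JavaConverters._

    val df = Seq("""{"a": 1, "b": "x"}""").toDF("value")
    val options = Map.empty[String, String].asJava

    // Since Spark 2.2: DDL-formatted schema string
    df.select(from_json($"value", "a INT, b STRING", options))

    // Still accepted: the Spark 2.1 JSON-formatted schema string
    val jsonSchema =
      """{"type":"struct","fields":[
        |  {"name":"a","type":"integer","nullable":true,"metadata":{}},
        |  {"name":"b","type":"string","nullable":true,"metadata":{}}]}""".stripMargin
    df.select(from_json($"value", jsonSchema, options))
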
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 170c238c53..8465e8d036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
}
+ test("from_json uses DDL strings for defining a schema") {
+ val df = Seq("""{"a": 1, "b": "haa"}""").toDS()
+ checkAnswer(
+ df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())),
+ Row(Row(1, "haa")) :: Nil)
+ }
+
test("to_json - struct") {
val df = Seq(Tuple1(Tuple1(1))).toDF("a")
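
For comparison, the DDL string "a INT, b STRING" in the new from_json test above is shorthand
for an explicit StructType; the equivalent assertion with the strongly typed overload would look
roughly like this (a sketch, relying on the suite's checkAnswer helper and implicits):

    import org.apache.spark.sql.types._

    val jsonDF = Seq("""{"a": 1, "b": "haa"}""").toDS()
    val schema = new StructType().add("a", IntegerType).add("b", StringType)
    checkAnswer(
      jsonDF.select(from_json($"value", schema)),
      Row(Row(1, "haa")) :: Nil)
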
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 1607c97cd6..9f4009bfe4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
-import org.apache.spark.sql.{sources, Row, SparkSession}
+import org.apache.spark.sql.{sources, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection