Diffstat (limited to 'sql')
 sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala      | 10 +++++++---
 sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala    | 17 +++++++++++++++++
 sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala |  7 +++++++
 3 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 873221835d..0f27fd13e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -287,9 +287,13 @@ private[sql] object JsonRDD extends Logging {
// the ObjectMapper will take the last value associated with this duplicate key.
// For example: for {"key": 1, "key":2}, we will get "key"->2.
val mapper = new ObjectMapper()
- iter.map { record =>
- val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
- parsed.asInstanceOf[Map[String, Any]]
+ iter.flatMap { record =>
+ val parsed = mapper.readValue(record, classOf[Object]) match {
+ case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
+ case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
+ }
+
+ parsed
}
})
}
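The fix hinges on how Jackson binds untyped JSON: ObjectMapper.readValue with classOf[Object] yields a java.util.Map for a top-level object and a java.util.List for a top-level array, so a single call can tell the two shapes apart. A minimal standalone sketch of that behavior (not part of the patch; it assumes the same com.fasterxml.jackson.databind.ObjectMapper that JsonRDD already uses, and the demo object name is hypothetical):

    import com.fasterxml.jackson.databind.ObjectMapper

    object TopLevelShapeDemo {
      def main(args: Array[String]): Unit = {
        val mapper = new ObjectMapper()
        // A top-level JSON object binds to a java.util.Map implementation.
        println(mapper.readValue("""{"a": 1}""", classOf[Object]).getClass)
        // A top-level JSON array binds to a java.util.List, which is why the
        // patch switches iter.map to iter.flatMap: one input record can now
        // produce zero or more rows.
        println(mapper.readValue("""[{"a": 1}, {"b": 2}]""", classOf[Object]).getClass)
      }
    }

Note that the match in the patched code covers only objects and arrays; a top-level scalar such as 1 binds to a boxed primitive and would fail with a MatchError.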
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index b50d938554..685e788207 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -622,4 +622,21 @@ class JsonSuite extends QueryTest {
("str1", Nil, "str4", 2) :: Nil
)
}
+
+ test("SPARK-3308 Read top level JSON arrays") {
+ val jsonSchemaRDD = jsonRDD(jsonArray)
+ jsonSchemaRDD.registerTempTable("jsonTable")
+
+ checkAnswer(
+ sql(
+ """
+ |select a, b, c
+ |from jsonTable
+ """.stripMargin),
+ ("str_a_1", null, null) ::
+ ("str_a_2", null, null) ::
+ (null, "str_b_3", null) ::
+ ("str_a_4", "str_b_4", "str_c_4") ::Nil
+ )
+ }
}
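The new test also exercises schema unification: the records contribute different fields, and any field missing from a record surfaces as null in the result, which is what the expected tuples encode. A hedged inspection sketch (assumes the suite's TestSQLContext imports and the jsonArray fixture added below):

    import org.apache.spark.sql.json.TestJsonData
    import org.apache.spark.sql.test.TestSQLContext._

    val jsonSchemaRDD = jsonRDD(TestJsonData.jsonArray)
    jsonSchemaRDD.printSchema()
    // Inferred over all four input strings:
    // root
    //  |-- a: string (nullable = true)
    //  |-- b: string (nullable = true)
    //  |-- c: string (nullable = true)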
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index 5f0b3959a6..fc833b8b54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -136,4 +136,11 @@ object TestJsonData {
]
]]
}""" :: Nil)
+
+ val jsonArray =
+ TestSQLContext.sparkContext.parallelize(
+ """[{"a":"str_a_1"}]""" ::
+ """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
+ """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
+ """[]""" :: Nil)
}
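To see how this fixture flattens under the patched flatMap, here is a standalone sketch (no Spark context needed; the demo object name is hypothetical) that replays the same match over each input string:

    import scala.collection.JavaConverters._
    import com.fasterxml.jackson.databind.ObjectMapper

    object JsonArrayFlattenDemo {
      def main(args: Array[String]): Unit = {
        val mapper = new ObjectMapper()
        val lines = Seq(
          """[{"a":"str_a_1"}]""",
          """[{"a":"str_a_2"}, {"b":"str_b_3"}]""",
          """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""",
          """[]""")
        val records = lines.flatMap { line =>
          mapper.readValue(line, classOf[Object]) match {
            // A bare object is a single record; an array contributes
            // one record per element (possibly none).
            case map: java.util.Map[_, _] => Seq(map)
            case list: java.util.List[_]  => list.asScala
          }
        }
        println(records.size) // 4 = 1 + 2 + 1 + 0, matching the four rows in JsonSuite
      }
    }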