-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala       | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala     | 28
3 files changed, 41 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 8161151358..0e22375805 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -150,10 +150,10 @@ private[sql] object JacksonParser {
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
- valueType: DataType): Map[String, Any] = {
- val builder = Map.newBuilder[String, Any]
+ valueType: DataType): Map[UTF8String, Any] = {
+ val builder = Map.newBuilder[UTF8String, Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
- builder += parser.getCurrentName -> convertField(factory, parser, valueType)
+ builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType)
}
builder.result()
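
The hunk above is the core of the fix on the streaming (Jackson) path: each map key is wrapped in Spark's internal UTF8String at parse time instead of being left as a java.lang.String. A minimal, self-contained sketch of the pattern, with a hypothetical Utf8Key case class standing in for UTF8String and a plain iterator standing in for the Jackson parser:

final case class Utf8Key(value: String) // hypothetical stand-in for Spark's internal UTF8String

def convertMapSketch(fields: Iterator[(String, Any)]): Map[Utf8Key, Any] = {
  val builder = Map.newBuilder[Utf8Key, Any]
  for ((name, value) <- fields) {
    builder += Utf8Key(name) -> value // wrapping the key here is the whole fix
  }
  builder.result()
}

Leaving the keys as raw strings is what later breaks consumers that expect the internal string representation for StringType data, such as the Parquet write path exercised by the test below.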
@@ -181,7 +181,7 @@ private[sql] object JacksonParser {
val row = new GenericMutableRow(schema.length)
for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
require(schema(corruptIndex).dataType == StringType)
- row.update(corruptIndex, record)
+ row.update(corruptIndex, UTF8String(record))
}
Seq(row)
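
The second hunk applies the same wrapping to the corrupt-record column: the raw unparseable text is stored in the internal string type, since that column is declared StringType. A hedged sketch of the idea, reusing the Utf8Key stand-in from the previous sketch and a plain array in place of GenericMutableRow:

def corruptRowSketch(columns: Seq[String], corruptColumn: String, record: String): Array[Any] = {
  // Place the raw record text into the designated corrupt-record slot, wrapped
  // so it matches StringType's runtime representation. Lookup by name is illustrative.
  val row = new Array[Any](columns.length)
  val idx = columns.indexOf(corruptColumn)
  if (idx >= 0) row(idx) = Utf8Key(record)
  row
}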
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 4c32710a17..037a6d60a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,18 +20,18 @@ package org.apache.spark.sql.json
import java.sql.Timestamp
import scala.collection.Map
-import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
+import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
-import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException}
+import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
-import org.apache.spark.Logging
private[sql] object JsonRDD extends Logging {
@@ -318,7 +318,8 @@ private[sql] object JsonRDD extends Logging {
parsed
} catch {
- case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil
+ case e: JsonProcessingException =>
+ Map(columnNameOfCorruptRecords -> UTF8String(record)) :: Nil
}
}
})
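
The non-streaming JsonRDD path has no row to mutate at this point; a record that fails to parse is instead represented as a one-entry map from the corrupt-record column name to the raw text, which a later asRow step turns into a row. Only the value is wrapped, since the key is a schema field name rather than data. A small sketch with the same Utf8Key stand-in:

def parseFallbackSketch(corruptColumn: String, record: String): Seq[Map[String, Any]] =
  // One-entry map keyed by the corrupt-record column; the value carries the wrapped text.
  Map(corruptColumn -> Utf8Key(record)) :: Nil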
@@ -422,7 +423,10 @@ private[sql] object JsonRDD extends Logging {
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case MapType(StringType, valueType, _) =>
val map = value.asInstanceOf[Map[String, Any]]
- map.mapValues(enforceCorrectType(_, valueType)).map(identity)
+ map.map {
+ case (k, v) =>
+ (UTF8String(k), enforceCorrectType(v, valueType))
+ }.map(identity)
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
case DateType => toDate(value)
case TimestampType => toTimestamp(value)
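
This hunk is the fix for the MapType bug itself on the non-streaming path: the old mapValues call converted only the values and left the keys as java.lang.String (mapValues also returns a lazy view in Scala, which is why the trailing .map(identity) forces a concrete, serializable map). The replacement converts key and value together in one eager pass. A compact sketch, where coerce stands in for enforceCorrectType(_, valueType):

def enforceMapSketch(raw: Map[String, Any], coerce: Any => Any): Map[Utf8Key, Any] =
  // Convert keys to the internal string type and values recursively to the
  // declared value type, in a single strict pass over the map.
  raw.map { case (k, v) => Utf8Key(k) -> coerce(v) }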
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 6f747e5846..7e6eeba177 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -25,7 +25,6 @@ import org.scalactic.Tolerance._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
@@ -1074,4 +1073,31 @@ class JsonSuite extends QueryTest {
assert(StructType(Seq()) === emptySchema)
}
+ test("SPARK-7565 MapType in JsonRDD") {
+ val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
+ val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
+ TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+
+ val schemaWithSimpleMap = StructType(
+ StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+ try {
+ for (useStreaming <- List("true", "false")) {
+ setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ val temp = Utils.createTempDir().getPath
+
+ val df = read.schema(schemaWithSimpleMap).json(mapType1)
+ df.write.mode("overwrite").parquet(temp)
+ // entry order inside a MapType value is not defined, so only check the count
+ assert(read.parquet(temp).count() == 5)
+
+ val df2 = read.json(corruptRecords)
+ df2.write.mode("overwrite").parquet(temp)
+ checkAnswer(read.parquet(temp), df2.collect())
+ }
+ } finally {
+ setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ }
+ }
+
}
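
The regression test drives both code paths (USE_JACKSON_STREAMING_API set to "true" and "false") through a JSON-to-Parquet round trip, because the Parquet writer is where stray java.lang.String map keys break the write. For readers without a 1.4-era TestSQLContext at hand, here is a hedged, self-contained sketch of the same round trip against the modern public SparkSession API; the session settings, sample records, and temp path are illustrative and not part of the commit:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("SPARK-7565-roundtrip")
  .getOrCreate()
import spark.implicits._

// Schema with a map column, mirroring schemaWithSimpleMap in the test.
val schemaWithSimpleMap = StructType(
  StructField("map", MapType(StringType, IntegerType, valueContainsNull = true),
    nullable = false) :: Nil)

val temp = java.nio.file.Files.createTempDirectory("spark-7565").toString

// Read JSON with a map-typed column, write it as Parquet, and read it back.
val json = Seq("""{"map": {"a": 1}}""", """{"map": {"b": 2, "c": 3}}""").toDS()
val df = spark.read.schema(schemaWithSimpleMap).json(json)
df.write.mode("overwrite").parquet(temp)

// Entry order inside a map is undefined, so only the row count is asserted,
// just as in the committed test.
assert(spark.read.parquet(temp).count() == 2)
spark.stop()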