diff options
author | Dan McClary <dan.mcclary@gmail.com> | 2014-11-20 13:36:50 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-11-20 13:44:19 -0800 |
commit | b8e6886fb8ff8f667fb7e600cd727d8649cad1d1 (patch) | |
tree | 17e5da1690f456271a88667790630f8bb5fffc67 | |
parent | abf29187f0342b607fcefe269391d4db58d2a957 (diff) | |
download | spark-b8e6886fb8ff8f667fb7e600cd727d8649cad1d1.tar.gz spark-b8e6886fb8ff8f667fb7e600cd727d8649cad1d1.tar.bz2 spark-b8e6886fb8ff8f667fb7e600cd727d8649cad1d1.zip |
[SPARK-4228][SQL] SchemaRDD to JSON
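Here's a simple fix for SchemaRDD to JSON: it adds a `toJSON` method to SchemaRDD (with Java and PySpark wrappers) that converts each row into a JSON string.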
Author: Dan McClary <dan.mcclary@gmail.com>
Closes #3213 from dwmclary/SPARK-4228 and squashes the following commits:
d714e1d [Dan McClary] fixed PEP 8 error
cac2879 [Dan McClary] move pyspark comment and doctest to correct location
f9471d3 [Dan McClary] added pyspark doc and doctest
6598cee [Dan McClary] adding complex type queries
1a5fd30 [Dan McClary] removing SPARK-4228 from SQLQuerySuite
4a651f0 [Dan McClary] cleaned PEP and Scala style failures. Moved tests to JsonSuite
47ceff6 [Dan McClary] cleaned up scala style issues
2ee1e70 [Dan McClary] moved rowToJSON to JsonRDD
4387dd5 [Dan McClary] Added UserDefinedType, cleaned up case formatting
8f7bfb6 [Dan McClary] Map type added to SchemaRDD.toJSON
1b11980 [Dan McClary] Map and UserDefinedTypes partially done
11d2016 [Dan McClary] formatting and unicode deserialization default fixed
6af72d1 [Dan McClary] deleted extraneous comment
4d11c0c [Dan McClary] JsonFactory rewrite of toJSON for SchemaRDD
149dafd [Dan McClary] wrapped scala toJSON in sql.py
5e5eb1b [Dan McClary] switched to Jackson for JSON processing
6c94a54 [Dan McClary] added toJSON to pyspark SchemaRDD
aaeba58 [Dan McClary] added toJSON to pyspark SchemaRDD
1d171aa [Dan McClary] updated missing brace on if statement
319e3ba [Dan McClary] updated to upstream master with merged SPARK-4228
424f130 [Dan McClary] tests pass, ready for pull and PR
626a5b1 [Dan McClary] added toJSON to SchemaRDD
f7d166a [Dan McClary] added toJSON method
5d34e37 [Dan McClary] merge resolved
d6d19e9 [Dan McClary] pr example
5 files changed, 224 insertions, 4 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e5d62a466c..abb284d1e3 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -45,7 +45,7 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer + CloudPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -1870,6 +1870,21 @@ class SchemaRDD(RDD): rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() return SchemaRDD(rdd, self.sql_ctx) + def toJSON(self, use_unicode=False): + """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") + >>> srdd2 = sqlCtx.sql( "SELECT * from table1") + >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + True + >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + True + """ + rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 904a276ef3..f8970cd3e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql -import java.util.{List => JList} - -import org.apache.spark.api.python.SerDeUtil +import java.util.{Map => JMap, List => JList} +import java.io.StringWriter import scala.collection.JavaConversions._ +import com.fasterxml.jackson.core.JsonFactory + import net.razorvine.pickle.Pickler import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -35,6 +37,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} import org.apache.spark.storage.StorageLevel @@ -131,6 +134,20 @@ class SchemaRDD( */ lazy val schema: StructType = queryExecution.analyzed.schema + /** + * Returns a new RDD with each row transformed to a JSON string. 
+ * + * @group schema + */ + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val jsonFactory = new JsonFactory() + iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) + } + } + + // ======================================================================= // Query DSL // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 78e8d908fe..ac4844f9b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -126,6 +126,12 @@ class JavaSchemaRDD( // Transformations (return a new RDD) /** + * Returns a new RDD with each row transformed to a JSON string. + */ + def toJSON(): JavaRDD[String] = + baseSchemaRDD.toJSON.toJavaRDD + + /** * Return a new RDD that is reduced into `numPartitions` partitions. 
*/ def coalesce(numPartitions: Int, shuffle: Boolean = false): JavaSchemaRDD = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index d9d7a3fea3..ffb9548356 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,12 +20,15 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.types.util.DataTypeConversions +import java.io.StringWriter + import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonProcessingException +import com.fasterxml.jackson.core.JsonFactory import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -424,4 +427,61 @@ private[sql] object JsonRDD extends Logging { row } + + /** Transforms a single Row to JSON using Jackson + * + * @param jsonFactory a JsonFactory object to construct a JsonGenerator + * @param rowSchema the schema object used for conversion + * @param row The row to convert + */ + private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = { + val writer = new StringWriter() + val gen = jsonFactory.createGenerator(writer) + + def valWriter: (DataType, Any) => Unit = { + case (_, null) | (NullType, _) => gen.writeNull() + case (StringType, v: String) => gen.writeString(v) + case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) + case (IntegerType, v: Int) => gen.writeNumber(v) + case (ShortType, v: Short) => gen.writeNumber(v) + case (FloatType, v: Float) => gen.writeNumber(v) + case (DoubleType, v: Double) => gen.writeNumber(v) + case (LongType, v: Long) => gen.writeNumber(v) + case (DecimalType(), v: scala.math.BigDecimal) => gen.writeNumber(v.bigDecimal) 
+ case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) + case (ByteType, v: Byte) => gen.writeNumber(v.toInt) + case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) + case (BooleanType, v: Boolean) => gen.writeBoolean(v) + case (DateType, v) => gen.writeString(v.toString) + case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) + + case (ArrayType(ty, _), v: Seq[_] ) => + gen.writeStartArray() + v.foreach(valWriter(ty,_)) + gen.writeEndArray() + + case (MapType(kv,vv, _), v: Map[_,_]) => + gen.writeStartObject + v.foreach { p => + gen.writeFieldName(p._1.toString) + valWriter(vv,p._2) + } + gen.writeEndObject + + case (StructType(ty), v: Seq[_]) => + gen.writeStartObject() + ty.zip(v).foreach { + case (_, null) => + case (field, v) => + gen.writeFieldName(field.name) + valWriter(field.dataType, v) + } + gen.writeEndObject() + } + + valWriter(rowSchema, row) + gen.close() + writer.toString + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index f8ca2c773d..f088d41325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -779,4 +780,125 @@ class JsonSuite extends QueryTest { Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil ) } + + test("SPARK-4228 SchemaRDD to JSON") + { + val schema1 = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", 
ArrayType(StringType), nullable = true) :: + StructField("f5", IntegerType, true) :: Nil) + + val rowRDD1 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v5 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) + } + + val schemaRDD1 = applySchema(rowRDD1, schema1) + schemaRDD1.registerTempTable("applySchema1") + val schemaRDD2 = schemaRDD1.toSchemaRDD + val result = schemaRDD2.toJSON.collect() + assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") + assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val schemaRDD3 = applySchema(rowRDD2, schema2) + schemaRDD3.registerTempTable("applySchema2") + val schemaRDD4 = schemaRDD3.toSchemaRDD + val result2 = schemaRDD4.toJSON.collect() + + assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") + assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") + + val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + val primTable = jsonRDD(jsonSchemaRDD.toJSON) + primTable.registerTempTable("primativeTable") + checkAnswer( + sql("select * from primativeTable"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + "this is a simple string.") :: Nil + ) + + val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) 
+ val compTable = jsonRDD(complexJsonSchemaRDD.toJSON) + compTable.registerTempTable("complexTable") + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), + ("str1", "str2", null) :: Nil + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from complexTable"), + Seq(Seq(null, null, null, null)) :: Nil + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), + (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), + (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), + (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), + ("str2", 2.1) :: Nil + ) + + // Access a struct and fields inside of it. + checkAnswer( + sql("select struct, struct.field1, struct.field2 from complexTable"), + Row( + Row(true, BigDecimal("92233720368547758070")), + true, + BigDecimal("92233720368547758070")) :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), + (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), + (5, null) :: Nil + ) + + } } |