From 89ea0041ae5a701ce8d211ed08f1f059b7f9c396 Mon Sep 17 00:00:00 2001
From: Nathan Howell
Date: Wed, 30 Sep 2015 15:33:12 -0700
Subject: [SPARK-9617] [SQL] Implement json_tuple

This is an implementation of Hive's `json_tuple` function using Jackson
Streaming.

Author: Nathan Howell

Closes #7946 from NathanHowell/SPARK-9617.
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../sql/catalyst/expressions/jsonExpressions.scala | 167 ++++++++++++++++++++-
 .../expressions/JsonExpressionsSuite.scala         | 114 ++++++++++++++
 .../org/apache/spark/sql/JsonFunctionsSuite.scala  |  38 +++++
 4 files changed, 316 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11b4866bf2..e6122d92b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -184,6 +184,7 @@ object FunctionRegistry {
     expression[FormatNumber]("format_number"),
     expression[GetJsonObject]("get_json_object"),
     expression[InitCap]("initcap"),
+    expression[JsonTuple]("json_tuple"),
     expression[Lower]("lcase"),
     expression[Lower]("lower"),
     expression[Length]("length"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 23bfa18c94..0770fab0ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -21,8 +21,9 @@ import java.io.{StringWriter, ByteArrayOutputStream}
 import com.fasterxml.jackson.core._
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{StringType, DataType}
+import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 import scala.util.parsing.combinator.RegexParsers
@@ -92,8 +93,8 @@ private[this] object JsonPathParser extends RegexParsers {
   }
 }
 
-private[this] object GetJsonObject {
-  private val jsonFactory = new JsonFactory()
+private[this] object SharedFactory {
+  val jsonFactory = new JsonFactory()
 
   // Enabled for Hive compatibility
   jsonFactory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS)
@@ -106,7 +107,7 @@
 case class GetJsonObject(json: Expression, path: Expression)
   extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
 
-  import GetJsonObject._
+  import SharedFactory._
   import PathInstruction._
   import WriteStyle._
   import com.fasterxml.jackson.core.JsonToken._
@@ -307,3 +308,161 @@ case class GetJsonObject(json: Expression, path: Expression)
     }
   }
 }
+
+case class JsonTuple(children: Seq[Expression])
+  extends Expression with CodegenFallback {
+
+  import SharedFactory._
+
+  override def nullable: Boolean = {
+    // a row is always returned
+    false
+  }
+
+  // if processing fails this shared value will be returned
+  @transient private lazy val nullRow: InternalRow =
+    new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length))
+
+  // the json body is the first child
+  @transient private lazy val jsonExpr: Expression = children.head
+
+  // the fields to query are the remaining children
+  @transient private lazy val fieldExpressions: Seq[Expression] = children.tail
+
+  // eagerly evaluate any foldable field names
+  @transient private lazy val foldableFieldNames: IndexedSeq[String] = {
+    fieldExpressions.map {
+      case expr if expr.foldable => expr.eval().asInstanceOf[UTF8String].toString
+      case _ => null
+    }.toIndexedSeq
+  }
+
+  // count the number of foldable fields; we'll use this later to optimize evaluation
+  @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null)
+
+  override lazy val dataType: StructType = {
+    val fields = fieldExpressions.zipWithIndex.map {
+      case (_, idx) => StructField(
+        name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize
+        dataType = StringType,
+        nullable = true)
+    }
+
+    StructType(fields)
+  }
+
+  override def prettyName: String = "json_tuple"
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.length < 2) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least two arguments")
+    } else if (children.forall(child => StringType.acceptsType(child.dataType))) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires that all arguments are strings")
+    }
+  }
+
+  override def eval(input: InternalRow): InternalRow = {
+    val json = jsonExpr.eval(input).asInstanceOf[UTF8String]
+    if (json == null) {
+      return nullRow
+    }
+
+    try {
+      val parser = jsonFactory.createParser(json.getBytes)
+
+      try {
+        parseRow(parser, input)
+      } finally {
+        parser.close()
+      }
+    } catch {
+      case _: JsonProcessingException =>
+        nullRow
+    }
+  }
+
+  private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = {
+    // only objects are supported
+    if (parser.nextToken() != JsonToken.START_OBJECT) {
+      return nullRow
+    }
+
+    // evaluate the field names as String rather than UTF8String to
+    // optimize lookups from the json token, which is also a String
+    val fieldNames = if (constantFields == fieldExpressions.length) {
+      // typically the user will provide the field names as foldable expressions
+      // so we can use the cached copy
+      foldableFieldNames
+    } else if (constantFields == 0) {
+      // none are foldable so all field names need to be evaluated from the input row
+      fieldExpressions.map(_.eval(input).asInstanceOf[UTF8String].toString)
+    } else {
+      // if there is a mix of constant and non-constant expressions
+      // prefer the cached copy when available
+      foldableFieldNames.zip(fieldExpressions).map {
+        case (null, expr) => expr.eval(input).asInstanceOf[UTF8String].toString
+        case (fieldName, _) => fieldName
+      }
+    }
+
+    val row = Array.ofDim[Any](fieldNames.length)
+
+    // start reading through the token stream, looking for any requested field names
+    while (parser.nextToken() != JsonToken.END_OBJECT) {
+      if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
+        // check to see if this field is desired in the output
+        val idx = fieldNames.indexOf(parser.getCurrentName)
+        if (idx >= 0) {
+          // it is, so copy the child tree to the correct location in the output row
+          val output = new ByteArrayOutputStream()
+
+          // write the output directly to a UTF8-encoded byte array
+          if (parser.nextToken() != JsonToken.VALUE_NULL) {
+            val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8)
+
+            try {
+              copyCurrentStructure(generator, parser)
+            } finally {
+              generator.close()
+            }
+
+            row(idx) = UTF8String.fromBytes(output.toByteArray)
+          }
+        }
+      }
+
+      // always skip children; it's cheap enough to do even if copyCurrentStructure was called
+      parser.skipChildren()
+    }
+
+    new GenericInternalRow(row)
+  }
+
+  private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
+    parser.getCurrentToken match {
+      // if the user requests a string field it needs to be returned without enclosing
+      // quotes, which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
+      case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
+        // slight optimization to avoid allocating a String instance, though the characters
+        // still have to be decoded... Jackson doesn't have a way to access the raw bytes
+        generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)
+
+      case JsonToken.VALUE_STRING =>
+        // the normal String case: pass it through to the output without enclosing quotes
+        generator.writeRaw(parser.getText)
+
+      case JsonToken.VALUE_NULL =>
+        // a special case that needs to be handled outside of this method.
+        // if a requested field is null, the result must be null. the easiest
+        // way to achieve this is just by ignoring null tokens entirely
+        throw new IllegalStateException("Do not attempt to copy a null field")
+
+      case _ =>
+        // handle other types including objects, arrays, booleans and numbers
+        generator.copyCurrentStructure(parser)
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 4addbaf0cb..f33125f463 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
 
 class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   val json =
@@ -199,4 +201,116 @@
       GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")),
       "1234")
   }
+
+  val jsonTupleQuery = Literal("f1") ::
+    Literal("f2") ::
+    Literal("f3") ::
+    Literal("f4") ::
+    Literal("f5") ::
+    Nil
+
+  test("json_tuple - hive key 1") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") ::
+          jsonTupleQuery),
+      InternalRow.fromSeq(Seq("value1", "value2", "3", null, "5.23").map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 2") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
+          jsonTupleQuery),
+      InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 2 (mix of foldable fields)") {
+    checkEvaluation(
+      JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
+        Literal("f1") ::
+        NonFoldableLiteral("f2") ::
+        NonFoldableLiteral("f3") ::
+        Literal("f4") ::
+        Literal("f5") ::
+        Nil),
+      InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 3") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") ::
+          jsonTupleQuery),
+      InternalRow.fromSeq(
Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 3 (nonfoldable json)") { + checkEvaluation( + JsonTuple( + NonFoldableLiteral( + """{"f1": "value13", "f4": "value44", + | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) + :: jsonTupleQuery), + InternalRow.fromSeq( + Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 3 (nonfoldable fields)") { + checkEvaluation( + JsonTuple(Literal( + """{"f1": "value13", "f4": "value44", + | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: + NonFoldableLiteral("f1") :: + NonFoldableLiteral("f2") :: + NonFoldableLiteral("f3") :: + NonFoldableLiteral("f4") :: + NonFoldableLiteral("f5") :: + Nil), + InternalRow.fromSeq( + Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 4 - null json") { + checkEvaluation( + JsonTuple(Literal(null) :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - hive key 5 - null and empty fields") { + checkEvaluation( + JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) + } + + test("json_tuple - hive key 6 - invalid json (array)") { + checkEvaluation( + JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (object start only)") { + checkEvaluation( + JsonTuple(Literal("{") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (no object end)") { + checkEvaluation( + JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (invalid json)") { + checkEvaluation( + JsonTuple(Literal("\\") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - preserve newlines") { + checkEvaluation( + JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), + InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 045fea82e4..e3531d0d6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -29,4 +29,42 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("alice", "5")) } + + val tuples: Seq[(String, String)] = + ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: + ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: + ("3", """{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: + ("4", null) :: + ("5", """{"f1": "", "f5": null}""") :: + ("6", "[invalid JSON string]") :: + Nil + + test("json_tuple select") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: + Row("2", Row("value12", "2", "value3", "4.01", null)) :: + Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: + Row("4", Row(null, null, null, null, null)) :: + Row("5", Row("", null, null, null, null)) :: + Row("6", Row(null, null, null, null, null)) :: + Nil + + 
+    checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected)
+  }
+
+  test("json_tuple filter and group") {
+    val df: DataFrame = tuples.toDF("key", "jstring")
+    val expr = df
+      .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt")
+      .where($"jt.c0".isNotNull)
+      .groupBy($"jt.c1")
+      .count()
+
+    val expected = Row(null, 1) ::
+      Row("2", 2) ::
+      Row("value2", 1) ::
+      Nil
+
+    checkAnswer(expr, expected)
+  }
 }
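
A quick usage sketch of the new function, mirroring the suites above. This is not part of the patch: it assumes a 1.5-era Spark shell with a `sqlContext` whose implicits are in scope, and the DataFrame and column names are illustrative only.

    import sqlContext.implicits._

    // one key column plus a JSON document to dissect
    val df = Seq(
      ("1", """{"f1": "value1", "f2": "value2"}"""),
      ("2", """{"f1": "value12"}""")).toDF("key", "jstring")

    // json_tuple extracts multiple top-level fields in a single pass over the
    // JSON; missing fields come back as null, and the result is a struct whose
    // columns are named c0, c1, ... as in Hive
    val jt = df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2') as jt")
    jt.select($"key", $"jt.c0", $"jt.c1").show()

Because the field-name arguments here are literals, they are foldable and hit the cached-name fast path in JsonTuple.parseRow; non-constant field names are re-evaluated per input row.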