author    Nathan Howell <nhowell@godaddy.com>    2015-09-30 15:33:12 -0700
committer Yin Huai <yhuai@databricks.com>        2015-09-30 15:33:12 -0700
commit    89ea0041ae5a701ce8d211ed08f1f059b7f9c396 (patch)
tree      2f0a965865b4eeae698edfe0e00997a53f63b9b3 /sql
parent    03cca5dce2cd7618b5c0e33163efb8502415b06e (diff)
[SPARK-9617] [SQL] Implement json_tuple
This is an implementation of Hive's `json_tuple` function using Jackson Streaming.

Author: Nathan Howell <nhowell@godaddy.com>

Closes #7946 from NathanHowell/SPARK-9617.
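As a quick illustration of the feature (a sketch mirroring the new JsonFunctionsSuite tests below; the sample data and column names here are hypothetical), `json_tuple` extracts several top-level fields from a JSON string in a single pass over the token stream:

    // assumes a SQLContext with its implicits in scope, as in the test suite
    val df = Seq(("1", """{"f1": "value1", "f2": "value2"}""")).toDF("key", "jstring")

    // each requested field becomes a struct field named c0, c1, ... (see JsonTuple.dataType)
    val jt = df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2') as jt")
    // collect() yields Row("1", Row("value1", "value2")); missing fields,
    // null fields and unparseable JSON all surface as null struct fields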
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala       |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala     | 167
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala | 114
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala                            |  38
4 files changed, 316 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11b4866bf2..e6122d92b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -184,6 +184,7 @@ object FunctionRegistry {
    expression[FormatNumber]("format_number"),
    expression[GetJsonObject]("get_json_object"),
    expression[InitCap]("initcap"),
+    expression[JsonTuple]("json_tuple"),
    expression[Lower]("lcase"),
    expression[Lower]("lower"),
    expression[Length]("length"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 23bfa18c94..0770fab0ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -21,8 +21,9 @@ import java.io.{StringWriter, ByteArrayOutputStream}
import com.fasterxml.jackson.core._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{StringType, DataType}
+import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType}
import org.apache.spark.unsafe.types.UTF8String
import scala.util.parsing.combinator.RegexParsers
@@ -92,8 +93,8 @@ private[this] object JsonPathParser extends RegexParsers {
  }
}
-private[this] object GetJsonObject {
-  private val jsonFactory = new JsonFactory()
+private[this] object SharedFactory {
+  val jsonFactory = new JsonFactory()
  // Enabled for Hive compatibility
  jsonFactory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS)
@@ -106,7 +107,7 @@ private[this] object GetJsonObject {
case class GetJsonObject(json: Expression, path: Expression)
  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
-  import GetJsonObject._
+  import SharedFactory._
  import PathInstruction._
  import WriteStyle._
  import com.fasterxml.jackson.core.JsonToken._
@@ -307,3 +308,161 @@ case class GetJsonObject(json: Expression, path: Expression)
    }
  }
}
+
+case class JsonTuple(children: Seq[Expression])
+  extends Expression with CodegenFallback {
+
+  import SharedFactory._
+
+  override def nullable: Boolean = {
+    // a row is always returned
+    false
+  }
+
+  // if processing fails this shared value will be returned
+  @transient private lazy val nullRow: InternalRow =
+    new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length))
+
+  // the json body is the first child
+  @transient private lazy val jsonExpr: Expression = children.head
+
+  // the fields to query are the remaining children
+  @transient private lazy val fieldExpressions: Seq[Expression] = children.tail
+
+  // eagerly evaluate any foldable field names
+  @transient private lazy val foldableFieldNames: IndexedSeq[String] = {
+    fieldExpressions.map {
+      case expr if expr.foldable => expr.eval().asInstanceOf[UTF8String].toString
+      case _ => null
+    }.toIndexedSeq
+  }
+
+  // and count the number of foldable fields; we'll use this later to optimize evaluation
+  @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null)
+
+  override lazy val dataType: StructType = {
+    val fields = fieldExpressions.zipWithIndex.map {
+      case (_, idx) => StructField(
+        name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize
+        dataType = StringType,
+        nullable = true)
+    }
+
+    StructType(fields)
+  }
+
+  override def prettyName: String = "json_tuple"
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.length < 2) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least two arguments")
+    } else if (children.forall(child => StringType.acceptsType(child.dataType))) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires that all arguments are strings")
+    }
+  }
+
+  override def eval(input: InternalRow): InternalRow = {
+    val json = jsonExpr.eval(input).asInstanceOf[UTF8String]
+    if (json == null) {
+      return nullRow
+    }
+
+    try {
+      val parser = jsonFactory.createParser(json.getBytes)
+
+      try {
+        parseRow(parser, input)
+      } finally {
+        parser.close()
+      }
+    } catch {
+      case _: JsonProcessingException =>
+        nullRow
+    }
+  }
+
+  private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = {
+    // only objects are supported
+    if (parser.nextToken() != JsonToken.START_OBJECT) {
+      return nullRow
+    }
+
+    // evaluate the field names as String rather than UTF8String to
+    // optimize lookups from the json token, which is also a String
+    val fieldNames = if (constantFields == fieldExpressions.length) {
+      // typically the user will provide the field names as foldable expressions
+      // so we can use the cached copy
+      foldableFieldNames
+    } else if (constantFields == 0) {
+      // none are foldable so all field names need to be evaluated from the input row
+      fieldExpressions.map(_.eval(input).asInstanceOf[UTF8String].toString)
+    } else {
+      // if there is a mix of constant and non-constant expressions
+      // prefer the cached copy when available
+      foldableFieldNames.zip(fieldExpressions).map {
+        case (null, expr) => expr.eval(input).asInstanceOf[UTF8String].toString
+        case (fieldName, _) => fieldName
+      }
+    }
+
+    val row = Array.ofDim[Any](fieldNames.length)
+
+    // start reading through the token stream, looking for any requested field names
+    while (parser.nextToken() != JsonToken.END_OBJECT) {
+      if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
+        // check to see if this field is desired in the output
+        val idx = fieldNames.indexOf(parser.getCurrentName)
+        if (idx >= 0) {
+          // it is; copy the child tree to the correct location in the output row
+          val output = new ByteArrayOutputStream()
+
+          // write the output directly to a UTF8 encoded byte array
+          if (parser.nextToken() != JsonToken.VALUE_NULL) {
+            val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8)
+
+            try {
+              copyCurrentStructure(generator, parser)
+            } finally {
+              generator.close()
+            }
+
+            row(idx) = UTF8String.fromBytes(output.toByteArray)
+          }
+        }
+      }
+
+      // always skip children; it's cheap enough to do even if copyCurrentStructure was called
+      parser.skipChildren()
+    }
+
+    new GenericInternalRow(row)
+  }
+
+  private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
+    parser.getCurrentToken match {
+      // if the user requests a string field it needs to be returned without enclosing
+      // quotes, which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
+      case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
+        // slight optimization to avoid allocating a String instance, though the characters
+        // still have to be decoded... Jackson doesn't have a way to access the raw bytes
+        generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)
+
+      case JsonToken.VALUE_STRING =>
+        // the normal String case, pass it through to the output without enclosing quotes
+        generator.writeRaw(parser.getText)
+
+      case JsonToken.VALUE_NULL =>
+        // a special case that needs to be handled outside of this method.
+        // if a requested field is null, the result must be null; the easiest
+        // way to achieve this is just by ignoring null tokens entirely
+        throw new IllegalStateException("Do not attempt to copy a null field")
+
+      case _ =>
+        // handle other types including objects, arrays, booleans and numbers
+        generator.copyCurrentStructure(parser)
+    }
+  }
+}
+
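For readers new to Jackson's streaming API, the following is a minimal, self-contained sketch (not part of the patch; the object and field names are invented) of the token-walk-and-copy pattern that parseRow and copyCurrentStructure build on:

    import java.io.ByteArrayOutputStream

    import com.fasterxml.jackson.core.{JsonEncoding, JsonFactory, JsonToken}

    object JsonCopySketch {
      def main(args: Array[String]): Unit = {
        val factory = new JsonFactory()
        val parser = factory.createParser("""{"skip": [1, 2], "want": {"nested": true}}""")

        try {
          // only objects are supported, just as in JsonTuple.parseRow
          if (parser.nextToken() == JsonToken.START_OBJECT) {
            while (parser.nextToken() != JsonToken.END_OBJECT) {
              if (parser.getCurrentToken == JsonToken.FIELD_NAME &&
                  parser.getCurrentName == "want") {
                parser.nextToken() // advance from the field name to its value
                val output = new ByteArrayOutputStream()
                val generator = factory.createGenerator(output, JsonEncoding.UTF8)
                try {
                  generator.copyCurrentStructure(parser) // copies the entire subtree
                } finally {
                  generator.close()
                }
                println(output.toString("UTF-8")) // prints {"nested":true}
              }
              // a cheap no-op unless the current token opens an object or array
              parser.skipChildren()
            }
          }
        } finally {
          parser.close()
        }
      }
    }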
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 4addbaf0cb..f33125f463 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
  val json =
@@ -199,4 +201,116 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
      GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")),
      "1234")
  }
+
+  val jsonTupleQuery = Literal("f1") ::
+    Literal("f2") ::
+    Literal("f3") ::
+    Literal("f4") ::
+    Literal("f5") ::
+    Nil
+
+  test("json_tuple - hive key 1") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") ::
+        jsonTupleQuery),
+      InternalRow.fromSeq(Seq("value1", "value2", "3", null, "5.23").map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 2") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
+        jsonTupleQuery),
+      InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 2 (mix of foldable fields)") {
+    checkEvaluation(
+      JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
+        Literal("f1") ::
+        NonFoldableLiteral("f2") ::
+        NonFoldableLiteral("f3") ::
+        Literal("f4") ::
+        Literal("f5") ::
+        Nil),
+      InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 3") {
+    checkEvaluation(
+      JsonTuple(
+        Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") ::
+        jsonTupleQuery),
+      InternalRow.fromSeq(
+        Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 3 (nonfoldable json)") {
+    checkEvaluation(
+      JsonTuple(
+        NonFoldableLiteral(
+          """{"f1": "value13", "f4": "value44",
+            | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin)
+          :: jsonTupleQuery),
+      InternalRow.fromSeq(
+        Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 3 (nonfoldable fields)") {
+    checkEvaluation(
+      JsonTuple(Literal(
+        """{"f1": "value13", "f4": "value44",
+          | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) ::
+        NonFoldableLiteral("f1") ::
+        NonFoldableLiteral("f2") ::
+        NonFoldableLiteral("f3") ::
+        NonFoldableLiteral("f4") ::
+        NonFoldableLiteral("f5") ::
+        Nil),
+      InternalRow.fromSeq(
+        Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString)))
+  }
+
+  test("json_tuple - hive key 4 - null json") {
+    checkEvaluation(
+      JsonTuple(Literal(null) :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(null, null, null, null, null)))
+  }
+
+  test("json_tuple - hive key 5 - null and empty fields") {
+    checkEvaluation(
+      JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null)))
+  }
+
+  test("json_tuple - hive key 6 - invalid json (array)") {
+    checkEvaluation(
+      JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(null, null, null, null, null)))
+  }
+
+  test("json_tuple - invalid json (object start only)") {
+    checkEvaluation(
+      JsonTuple(Literal("{") :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(null, null, null, null, null)))
+  }
+
+  test("json_tuple - invalid json (no object end)") {
+    checkEvaluation(
+      JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(null, null, null, null, null)))
+  }
+
+  test("json_tuple - invalid json (invalid json)") {
+    checkEvaluation(
+      JsonTuple(Literal("\\") :: jsonTupleQuery),
+      InternalRow.fromSeq(Seq(null, null, null, null, null)))
+  }
+
+  test("json_tuple - preserve newlines") {
+    checkEvaluation(
+      JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil),
+      InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc"))))
+  }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 045fea82e4..e3531d0d6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -29,4 +29,42 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row("alice", "5"))
}
+
+  val tuples: Seq[(String, String)] =
+    ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") ::
+    ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
+    ("3", """{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") ::
+    ("4", null) ::
+    ("5", """{"f1": "", "f5": null}""") ::
+    ("6", "[invalid JSON string]") ::
+    Nil
+
+  test("json_tuple select") {
+    val df: DataFrame = tuples.toDF("key", "jstring")
+    val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) ::
+      Row("2", Row("value12", "2", "value3", "4.01", null)) ::
+      Row("3", Row("value13", "2", "value33", "value44", "5.01")) ::
+      Row("4", Row(null, null, null, null, null)) ::
+      Row("5", Row("", null, null, null, null)) ::
+      Row("6", Row(null, null, null, null, null)) ::
+      Nil
+
+    checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected)
+  }
+
+  test("json_tuple filter and group") {
+    val df: DataFrame = tuples.toDF("key", "jstring")
+    val expr = df
+      .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt")
+      .where($"jt.c0".isNotNull)
+      .groupBy($"jt.c1")
+      .count()
+
+    val expected = Row(null, 1) ::
+      Row("2", 2) ::
+      Row("value2", 1) ::
+      Nil
+
+    checkAnswer(expr, expected)
+  }
}