-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala       | 218
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala  | 294
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                         | 136
3 files changed, 333 insertions(+), 315 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 893af5146c..83cb375525 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// this child in all children.
case (name, value: TreeNode[_]) if containsChild(value) =>
name -> JInt(children.indexOf(value))
- case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+ case (name, value: Seq[BaseType]) if value.forall(containsChild) =>
name -> JArray(
value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
)
@@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming
// it to JSON may trigger OutOfMemoryError.
case m: Metadata => Metadata.empty.jsonValue
+ case clazz: Class[_] => JString(clazz.getName)
case s: StorageLevel =>
("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
case n: TreeNode[_] => n.jsonValue
case o: Option[_] => o.map(parseToJson)
- case t: Seq[_] => JArray(t.map(parseToJson).toList)
- case m: Map[_, _] =>
- val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
- JObject(fields)
- case r: RDD[_] => JNothing
+ // Recursively scan Seq[TreeNode], Seq[Partitioning], Seq[DataType]
+ case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) ||
+ t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) =>
+ JArray(t.map(parseToJson).toList)
+ case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] =>
+ JString(Utils.truncatedString(t, "[", ", ", "]"))
+ case t: Seq[_] => JNull
+ case m: Map[_, _] => JNull
// if it's a scala object, we can simply keep the full class path.
// TODO: currently if the class name ends with "$", we think it's a scala object, there is
// probably a better way to check it.
case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
- // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
- case p: Product => try {
- val fieldNames = getConstructorParameterNames(p.getClass)
- val fieldValues = p.productIterator.toSeq
- assert(fieldNames.length == fieldValues.length)
- ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
- case (name, value) => name -> parseToJson(value)
- }.toList
- } catch {
- case _: RuntimeException => null
- }
- case _ => JNull
- }
-}
-
-object TreeNode {
- def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
- val jsonAST = parse(json)
- assert(jsonAST.isInstanceOf[JArray])
- reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
- }
-
- private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
- assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
- val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
-
- def parseNextNode(): TreeNode[_] = {
- val nextNode = jsonNodes.pop()
-
- val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
- if (cls == classOf[Literal]) {
- Literal.fromJSON(nextNode)
- } else if (cls.getName.endsWith("$")) {
- cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
- } else {
- val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
-
- val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
- val fields = getConstructorParameters(cls)
-
- val parameters: Array[AnyRef] = fields.map {
- case (fieldName, fieldType) =>
- parseFromJson(nextNode \ fieldName, fieldType, children, sc)
- }.toArray
-
- val maybeCtor = cls.getConstructors.find { p =>
- val expectedTypes = p.getParameterTypes
- expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
- case (cls, tpe) => cls == getClassFromType(tpe)
- }
- }
- if (maybeCtor.isEmpty) {
- sys.error(s"No valid constructor for ${cls.getName}")
- } else {
- try {
- maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
- } catch {
- case e: java.lang.IllegalArgumentException =>
- throw new RuntimeException(
- s"""
- |Failed to construct tree node: ${cls.getName}
- |ctor: ${maybeCtor.get}
- |types: ${parameters.map(_.getClass).mkString(", ")}
- |args: ${parameters.mkString(", ")}
- """.stripMargin, e)
- }
- }
- }
- }
-
- parseNextNode()
- }
-
- import universe._
-
- private def parseFromJson(
- value: JValue,
- expectedType: Type,
- children: Seq[TreeNode[_]],
- sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
- if (value == JNull) return null
-
- expectedType match {
- case t if t <:< definitions.BooleanTpe =>
- value.asInstanceOf[JBool].value: java.lang.Boolean
- case t if t <:< definitions.ByteTpe =>
- value.asInstanceOf[JInt].num.toByte: java.lang.Byte
- case t if t <:< definitions.ShortTpe =>
- value.asInstanceOf[JInt].num.toShort: java.lang.Short
- case t if t <:< definitions.IntTpe =>
- value.asInstanceOf[JInt].num.toInt: java.lang.Integer
- case t if t <:< definitions.LongTpe =>
- value.asInstanceOf[JInt].num.toLong: java.lang.Long
- case t if t <:< definitions.FloatTpe =>
- value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
- case t if t <:< definitions.DoubleTpe =>
- value.asInstanceOf[JDouble].num: java.lang.Double
-
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- value.asInstanceOf[JBool].value: java.lang.Boolean
- case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
- case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
- case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
- case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
- case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
- case t if t <:< localTypeOf[StorageLevel] =>
- val JBool(useDisk) = value \ "useDisk"
- val JBool(useMemory) = value \ "useMemory"
- val JBool(useOffHeap) = value \ "useOffHeap"
- val JBool(deserialized) = value \ "deserialized"
- val JInt(replication) = value \ "replication"
- StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
- case t if t <:< localTypeOf[TreeNode[_]] => value match {
- case JInt(i) => children(i.toInt)
- case arr: JArray => reconstruct(arr, sc)
- case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+ case p: Product if shouldConvertToJson(p) =>
+ try {
+ val fieldNames = getConstructorParameterNames(p.getClass)
+ val fieldValues = p.productIterator.toSeq
+ assert(fieldNames.length == fieldValues.length)
+ ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+ case (name, value) => name -> parseToJson(value)
+ }.toList
+ } catch {
+ case _: RuntimeException => null
}
- case t if t <:< localTypeOf[Option[_]] =>
- if (value == JNothing) {
- None
- } else {
- val TypeRef(_, _, Seq(optType)) = t
- Option(parseFromJson(value, optType, children, sc))
- }
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val JArray(elements) = value
- elements.map(parseFromJson(_, elementType, children, sc)).toSeq
- case t if t <:< localTypeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val JObject(fields) = value
- fields.map {
- case (name, value) => name -> parseFromJson(value, valueType, children, sc)
- }.toMap
- case t if t <:< localTypeOf[RDD[_]] =>
- new EmptyRDD[Any](sc)
- case _ if isScalaObject(value) =>
- val JString(clsName) = value \ "object"
- val cls = Utils.classForName(clsName)
- cls.getField("MODULE$").get(cls)
- case t if t <:< localTypeOf[Product] =>
- val fields = getConstructorParameters(t)
- val clsName = getClassNameFromType(t)
- parseToProduct(clsName, fields, value, children, sc)
- // There maybe some cases that the parameter type signature is not Product but the value is,
- // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here.
- case _ if isScalaProduct(value) =>
- val JString(clsName) = value \ "product-class"
- val fields = getConstructorParameters(Utils.classForName(clsName))
- parseToProduct(clsName, fields, value, children, sc)
- case _ => sys.error(s"Do not support type $expectedType with json $value.")
- }
- }
-
- private def parseToProduct(
- clsName: String,
- fields: Seq[(String, Type)],
- value: JValue,
- children: Seq[TreeNode[_]],
- sc: SparkContext): AnyRef = {
- val parameters: Array[AnyRef] = fields.map {
- case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
- }.toArray
- val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
- ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
- }
-
- private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
- case JString(str) if str.endsWith("$") => true
- case _ => false
+ case _ => JNull
}
- private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
- case _: JString => true
+ private def shouldConvertToJson(product: Product): Boolean = product match {
+ case exprId: ExprId => true
+ case field: StructField => true
+ case id: TableIdentifier => true
+ case join: JoinType => true
+ case id: FunctionIdentifier => true
+ case spec: BucketSpec => true
+ case catalog: CatalogTable => true
+ case boundary: FrameBoundary => true
+ case frame: WindowFrame => true
+ case partition: Partitioning => true
+ case resource: FunctionResource => true
+ case broadcast: BroadcastMode => true
+ case table: CatalogTableType => true
+ case storage: CatalogStorageFormat => true
case _ => false
}
}
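
The net effect of the TreeNode.scala hunks above: toJSON keeps expanding only the Product types whitelisted in shouldConvertToJson, collapses every other Seq, Map, and case class to JNull, and drops the fromJSON deserializer entirely. A minimal sketch of the new behavior, assuming a spark-shell with the catalyst classes on the classpath; UnknownArg is a hypothetical case class used purely for illustration, not part of Spark:

import org.apache.spark.sql.catalyst.expressions.Literal

// Whitelisted argument types (DataType, ExprId, StructField, ...) are still expanded,
// so a simple expression serializes to a readable JSON array of nodes:
Literal(333).toJSON
// roughly: [{"class":"...expressions.Literal","num-children":0,"value":"333","dataType":"integer"}]

// A case class outside the whitelist is no longer reflected over; it would be
// rendered as null, which is what keeps huge payloads out of the JSON output.
case class UnknownArg(payload: Array[Int])   // hypothetical, for illustration only
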
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6246380dbe..cb0426c7a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -17,13 +17,29 @@
package org.apache.spark.sql.catalyst.trees
+import java.math.BigInteger
+import java.util.UUID
+
import scala.collection.mutable.ArrayBuffer
+import org.json4s.jackson.JsonMethods
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{IntegerType, NullType, StringType}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
+import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
override def children: Seq[Expression] = optKey.toSeq
@@ -45,6 +61,20 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Expression with
override lazy val resolved = true
}
+case class JsonTestTreeNode(arg: Any) extends LeafNode {
+ override def output: Seq[Attribute] = Seq.empty[Attribute]
+}
+
+case class NameValue(name: String, value: Any)
+
+case object DummyObject
+
+case class SelfReferenceUDF(
+ var config: Map[String, Any] = Map.empty[String, Any]) extends Function1[String, Boolean] {
+ config += "self" -> this
+ def apply(key: String): Boolean = config.contains(key)
+}
+
class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -261,4 +291,264 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === expected)
}
}
+
+ test("toJSON") {
+ def assertJSON(input: Any, json: JValue): Unit = {
+ val expected =
+ s"""
+ |[{
+ | "class": "${classOf[JsonTestTreeNode].getName}",
+ | "num-children": 0,
+ | "arg": ${compact(render(json))}
+ |}]
+ """.stripMargin
+ compareJSON(JsonTestTreeNode(input).toJSON, expected)
+ }
+
+ // Converts simple types to JSON
+ assertJSON(true, true)
+ assertJSON(33.toByte, 33)
+ assertJSON(44, 44)
+ assertJSON(55L, 55L)
+ assertJSON(3.0, 3.0)
+ assertJSON(4.0D, 4.0D)
+ assertJSON(BigInt(BigInteger.valueOf(88L)), 88L)
+ assertJSON(null, JNull)
+ assertJSON("text", "text")
+ assertJSON(Some("text"), "text")
+ compareJSON(JsonTestTreeNode(None).toJSON,
+ s"""[
+ | {
+ | "class": "${classOf[JsonTestTreeNode].getName}",
+ | "num-children": 0
+ | }
+ |]
+ """.stripMargin)
+
+ val uuid = UUID.randomUUID()
+ assertJSON(uuid, uuid.toString)
+
+ // Converts Spark SQL DataType to JSON
+ assertJSON(IntegerType, "integer")
+ assertJSON(Metadata.empty, JObject(Nil))
+ assertJSON(
+ StorageLevel.NONE,
+ JObject(
+ "useDisk" -> false,
+ "useMemory" -> false,
+ "useOffHeap" -> false,
+ "deserialized" -> false,
+ "replication" -> 1)
+ )
+
+ // Converts TreeNode argument to JSON
+ assertJSON(
+ Literal(333),
+ List(
+ JObject(
+ "class" -> classOf[Literal].getName,
+ "num-children" -> 0,
+ "value" -> "333",
+ "dataType" -> "integer")))
+
+ // Converts Seq[String] to JSON
+ assertJSON(Seq("1", "2", "3"), "[1, 2, 3]")
+
+ // Converts Seq[DataType] to JSON
+ assertJSON(Seq(IntegerType, DoubleType, FloatType), List("integer", "double", "float"))
+
+ // Converts Seq[Partitioning] to JSON
+ assertJSON(
+ Seq(SinglePartition, RoundRobinPartitioning(numPartitions = 3)),
+ List(
+ JObject("object" -> JString(SinglePartition.getClass.getName)),
+ JObject(
+ "product-class" -> classOf[RoundRobinPartitioning].getName,
+ "numPartitions" -> 3)))
+
+ // Converts case object to JSON
+ assertJSON(DummyObject, JObject("object" -> JString(DummyObject.getClass.getName)))
+
+ // Converts ExprId to JSON
+ assertJSON(
+ ExprId(0, uuid),
+ JObject(
+ "product-class" -> classOf[ExprId].getName,
+ "id" -> 0,
+ "jvmId" -> uuid.toString))
+
+ // Converts StructField to JSON
+ assertJSON(
+ StructField("field", IntegerType),
+ JObject(
+ "product-class" -> classOf[StructField].getName,
+ "name" -> "field",
+ "dataType" -> "integer",
+ "nullable" -> true,
+ "metadata" -> JObject(Nil)))
+
+ // Converts TableIdentifier to JSON
+ assertJSON(
+ TableIdentifier("table"),
+ JObject(
+ "product-class" -> classOf[TableIdentifier].getName,
+ "table" -> "table"))
+
+ // Converts JoinType to JSON
+ assertJSON(
+ NaturalJoin(LeftOuter),
+ JObject(
+ "product-class" -> classOf[NaturalJoin].getName,
+ "tpe" -> JObject("object" -> JString(LeftOuter.getClass.getName))))
+
+ // Converts FunctionIdentifier to JSON
+ assertJSON(
+ FunctionIdentifier("function", None),
+ JObject(
+ "product-class" -> JString(classOf[FunctionIdentifier].getName),
+ "funcName" -> "function"))
+
+ // Converts BucketSpec to JSON
+ assertJSON(
+ BucketSpec(1, Seq("bucket"), Seq("sort")),
+ JObject(
+ "product-class" -> classOf[BucketSpec].getName,
+ "numBuckets" -> 1,
+ "bucketColumnNames" -> "[bucket]",
+ "sortColumnNames" -> "[sort]"))
+
+ // Converts FrameBoundary to JSON
+ assertJSON(
+ ValueFollowing(3),
+ JObject(
+ "product-class" -> classOf[ValueFollowing].getName,
+ "value" -> 3))
+
+ // Converts WindowFrame to JSON
+ assertJSON(
+ SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow),
+ JObject(
+ "product-class" -> classOf[SpecifiedWindowFrame].getName,
+ "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)),
+ "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)),
+ "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName))))
+
+ // Converts Partitioning to JSON
+ assertJSON(
+ RoundRobinPartitioning(numPartitions = 3),
+ JObject(
+ "product-class" -> classOf[RoundRobinPartitioning].getName,
+ "numPartitions" -> 3))
+
+ // Converts FunctionResource to JSON
+ assertJSON(
+ FunctionResource(JarResource, "file:///"),
+ JObject(
+ "product-class" -> JString(classOf[FunctionResource].getName),
+ "resourceType" -> JObject("object" -> JString(JarResource.getClass.getName)),
+ "uri" -> "file:///"))
+
+ // Converts BroadcastMode to JSON
+ assertJSON(
+ IdentityBroadcastMode,
+ JObject("object" -> JString(IdentityBroadcastMode.getClass.getName)))
+
+ // Converts CatalogTable to JSON
+ assertJSON(
+ CatalogTable(
+ TableIdentifier("table"),
+ CatalogTableType.MANAGED,
+ CatalogStorageFormat.empty,
+ StructType(StructField("a", IntegerType, true) :: Nil),
+ createTime = 0L),
+
+ JObject(
+ "product-class" -> classOf[CatalogTable].getName,
+ "identifier" -> JObject(
+ "product-class" -> classOf[TableIdentifier].getName,
+ "table" -> "table"
+ ),
+ "tableType" -> JObject(
+ "product-class" -> classOf[CatalogTableType].getName,
+ "name" -> "MANAGED"
+ ),
+ "storage" -> JObject(
+ "product-class" -> classOf[CatalogStorageFormat].getName,
+ "compressed" -> false,
+ "properties" -> JNull
+ ),
+ "schema" -> JObject(
+ "type" -> "struct",
+ "fields" -> List(
+ JObject(
+ "name" -> "a",
+ "type" -> "integer",
+ "nullable" -> true,
+ "metadata" -> JObject(Nil)))),
+ "partitionColumnNames" -> List.empty[String],
+ "owner" -> "",
+ "createTime" -> 0,
+ "lastAccessTime" -> -1,
+ "properties" -> JNull,
+ "unsupportedFeatures" -> List.empty[String]))
+
+ // For an unknown case class, returns JNull.
+ val bigValue = new Array[Int](10000)
+ assertJSON(NameValue("name", bigValue), JNull)
+
+ // Converts Seq[TreeNode] to JSON recursively
+ assertJSON(
+ Seq(Literal(1), Literal(2)),
+ List(
+ List(
+ JObject(
+ "class" -> JString(classOf[Literal].getName),
+ "num-children" -> 0,
+ "value" -> "1",
+ "dataType" -> "integer")),
+ List(
+ JObject(
+ "class" -> JString(classOf[Literal].getName),
+ "num-children" -> 0,
+ "value" -> "2",
+ "dataType" -> "integer"))))
+
+ // Any other Seq is converted to JNull, to reduce the risk of OutOfMemoryError
+ assertJSON(Seq(1, 2, 3), JNull)
+
+ // All Map types are converted to JNull, to reduce the risk of OutOfMemoryError
+ assertJSON(Map("key" -> "value"), JNull)
+
+ // Any unknown type is converted to JNull, to reduce the risk of OutOfMemoryError
+ assertJSON(new Object {}, JNull)
+
+ // Converts all TreeNode children to JSON
+ assertJSON(
+ Union(Seq(JsonTestTreeNode("0"), JsonTestTreeNode("1"))),
+ List(
+ JObject(
+ "class" -> classOf[Union].getName,
+ "num-children" -> 2,
+ "children" -> List(0, 1)),
+ JObject(
+ "class" -> classOf[JsonTestTreeNode].getName,
+ "num-children" -> 0,
+ "arg" -> "0"),
+ JObject(
+ "class" -> classOf[JsonTestTreeNode].getName,
+ "num-children" -> 0,
+ "arg" -> "1")))
+ }
+
+ test("toJSON should not throws java.lang.StackOverflowError") {
+ val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr))
+ // Should not throw java.lang.StackOverflowError
+ udf.toJSON
+ }
+
+ private def compareJSON(leftJson: String, rightJson: String): Unit = {
+ val left = JsonMethods.parse(leftJson)
+ val right = JsonMethods.parse(rightJson)
+ assert(left == right)
+ }
}
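
The last test above exists because SelfReferenceUDF stores itself in its own config map, so any serializer that recurses into Map values chases that cycle forever. A plain-Scala sketch of the cycle, as one might run it in a REPL, to show why treating non-whitelisted Products and Map[_, _] as JNull makes udf.toJSON terminate (SelfRef mirrors the SelfReferenceUDF defined in the test file):

// Self-referencing structure: the instance ends up inside its own field.
case class SelfRef(var config: Map[String, Any] = Map.empty[String, Any])
    extends (String => Boolean) {
  config += "self" -> this                  // cycle: config("self") eq this
  def apply(key: String): Boolean = config.contains(key)
}

val ref = SelfRef()
assert(ref.config("self").asInstanceOf[SelfRef] eq ref)   // a genuine cycle
// Old parseToJson: Product -> fields -> Map -> values -> same Product -> ... (StackOverflowError)
// New parseToJson: SelfRef is not in the shouldConvertToJson whitelist (and Map[_, _]
// maps to JNull anyway), so ScalaUDF(SelfRef(), BooleanType, ...).toJSON now terminates.
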
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d361f61764..34fa626e00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -120,7 +120,6 @@ abstract class QueryTest extends PlanTest {
throw ae
}
}
- checkJsonFormat(analyzedDS)
assertEmptyMissingInput(analyzedDS)
try ds.collect() catch {
@@ -168,8 +167,6 @@ abstract class QueryTest extends PlanTest {
}
}
- checkJsonFormat(analyzedDF)
-
assertEmptyMissingInput(analyzedDF)
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
@@ -228,139 +225,6 @@ abstract class QueryTest extends PlanTest {
planWithCaching)
}
- private def checkJsonFormat(ds: Dataset[_]): Unit = {
- // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
- // RDD and Data resolution does not break.
- val logicalPlan = ds.queryExecution.analyzed
-
- // bypass some cases that we can't handle currently.
- logicalPlan.transform {
- case _: ObjectConsumer => return
- case _: ObjectProducer => return
- case _: AppendColumns => return
- case _: TypedFilter => return
- case _: LogicalRelation => return
- case p if p.getClass.getSimpleName == "MetastoreRelation" => return
- case _: MemoryPlan => return
- case p: InMemoryRelation =>
- p.child.transform {
- case _: ObjectConsumerExec => return
- case _: ObjectProducerExec => return
- }
- p
- }.transformAllExpressions {
- case _: ImperativeAggregate => return
- case _: TypedAggregateExpression => return
- case Literal(_, _: ObjectType) => return
- case _: UserDefinedGenerator => return
- }
-
- // bypass hive tests before we fix all corner cases in hive module.
- if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
-
- val jsonString = try {
- logicalPlan.toJSON
- } catch {
- case NonFatal(e) =>
- fail(
- s"""
- |Failed to parse logical plan to JSON:
- |${logicalPlan.treeString}
- """.stripMargin, e)
- }
-
- // scala function is not serializable to JSON, use null to replace them so that we can compare
- // the plans later.
- val normalized1 = logicalPlan.transformAllExpressions {
- case udf: ScalaUDF => udf.copy(function = null)
- case gen: UserDefinedGenerator => gen.copy(function = null)
- // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove
- // the Metadata from the normalized plan so that we can compare this plan with the
- // JSON-deserialzed plan.
- case a @ Alias(child, name) if a.explicitMetadata.isDefined =>
- Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated)
- case a: AttributeReference if a.metadata != Metadata.empty =>
- AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier,
- a.isGenerated)
- }
-
- // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains
- // these non-serializable stuff, and use these original ones to replace the null-placeholders
- // in the logical plans parsed from JSON.
- val logicalRDDs = new ArrayDeque[LogicalRDD]()
- val localRelations = new ArrayDeque[LocalRelation]()
- val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
- def collectData: (LogicalPlan => Unit) = {
- case l: LogicalRDD =>
- logicalRDDs.offer(l)
- case l: LocalRelation =>
- localRelations.offer(l)
- case i: InMemoryRelation =>
- inMemoryRelations.offer(i)
- case p =>
- p.expressions.foreach {
- _.foreach {
- case s: SubqueryExpression =>
- s.plan.foreach(collectData)
- case _ =>
- }
- }
- }
- logicalPlan.foreach(collectData)
-
-
- val jsonBackPlan = try {
- TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
- } catch {
- case NonFatal(e) =>
- fail(
- s"""
- |Failed to rebuild the logical plan from JSON:
- |${logicalPlan.treeString}
- |
- |${logicalPlan.prettyJson}
- """.stripMargin, e)
- }
-
- def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
- case l: LogicalRDD =>
- val origin = logicalRDDs.pop()
- LogicalRDD(l.output, origin.rdd)(spark)
- case l: LocalRelation =>
- val origin = localRelations.pop()
- l.copy(data = origin.data)
- case l: InMemoryRelation =>
- val origin = inMemoryRelations.pop()
- InMemoryRelation(
- l.output,
- l.useCompression,
- l.batchSize,
- l.storageLevel,
- origin.child,
- l.tableName)(
- origin.cachedColumnBuffers,
- origin.batchStats)
- case p =>
- p.transformExpressions {
- case s: SubqueryExpression =>
- s.withNewPlan(s.plan.transformDown(renormalize))
- }
- }
- val normalized2 = jsonBackPlan.transformDown(renormalize)
-
- assert(logicalRDDs.isEmpty)
- assert(localRelations.isEmpty)
- assert(inMemoryRelations.isEmpty)
-
- if (normalized1 != normalized2) {
- fail(
- s"""
- |== FAIL: the logical plan parsed from json does not match the original one ===
- |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
- """.stripMargin)
- }
- }
-
/**
* Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
*/