aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache
diff options
context:
space:
mode:
authorSean Zhong <seanzhong@databricks.com>2016-09-16 19:37:30 +0800
committerWenchen Fan <wenchen@databricks.com>2016-09-16 19:37:30 +0800
commita425a37a5d894e0d7462c8faa81a913495189ece (patch)
tree47d5575d92f084993cacbe4caf87e7e5ed19fdd2 /sql/catalyst/src/main/scala/org/apache
parentfc1efb720c9c0033077c3c20ee144d0f757e6bcd (diff)
downloadspark-a425a37a5d894e0d7462c8faa81a913495189ece.tar.gz
spark-a425a37a5d894e0d7462c8faa81a913495189ece.tar.bz2
spark-a425a37a5d894e0d7462c8faa81a913495189ece.zip
[SPARK-17426][SQL] Refactor `TreeNode.toJSON` to avoid OOM when converting unknown fields to JSON
## What changes were proposed in this pull request? This PR is a follow up of SPARK-17356. Current implementation of `TreeNode.toJSON` recursively converts all fields of TreeNode to JSON, even if the field is of type `Seq` or type Map. This may trigger out of memory exception in cases like: 1. the Seq or Map can be very big. Converting them to JSON may take huge memory, which may trigger out of memory error. 2. Some user space input may also be propagated to the Plan. The user space input can be of arbitrary type, and may also be self-referencing. Trying to print user space input to JSON may trigger out of memory error or stack overflow error. For a code example, please check the Jira description of SPARK-17426. In this PR, we refactor the `TreeNode.toJSON` so that we only convert a field to JSON string if the field is a safe type. ## How was this patch tested? Unit test. Author: Sean Zhong <seanzhong@databricks.com> Closes #14990 from clockfly/json_oom2.
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala218
1 files changed, 41 insertions, 177 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 893af5146c..83cb375525 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// this child in all children.
case (name, value: TreeNode[_]) if containsChild(value) =>
name -> JInt(children.indexOf(value))
- case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+ case (name, value: Seq[BaseType]) if value.forall(containsChild) =>
name -> JArray(
value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
)
@@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming
// it to JSON may trigger OutOfMemoryError.
case m: Metadata => Metadata.empty.jsonValue
+ case clazz: Class[_] => JString(clazz.getName)
case s: StorageLevel =>
("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
case n: TreeNode[_] => n.jsonValue
case o: Option[_] => o.map(parseToJson)
- case t: Seq[_] => JArray(t.map(parseToJson).toList)
- case m: Map[_, _] =>
- val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
- JObject(fields)
- case r: RDD[_] => JNothing
+ // Recursive scan Seq[TreeNode], Seq[Partitioning], Seq[DataType]
+ case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) ||
+ t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) =>
+ JArray(t.map(parseToJson).toList)
+ case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] =>
+ JString(Utils.truncatedString(t, "[", ", ", "]"))
+ case t: Seq[_] => JNull
+ case m: Map[_, _] => JNull
// if it's a scala object, we can simply keep the full class path.
// TODO: currently if the class name ends with "$", we think it's a scala object, there is
// probably a better way to check it.
case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
- // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
- case p: Product => try {
- val fieldNames = getConstructorParameterNames(p.getClass)
- val fieldValues = p.productIterator.toSeq
- assert(fieldNames.length == fieldValues.length)
- ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
- case (name, value) => name -> parseToJson(value)
- }.toList
- } catch {
- case _: RuntimeException => null
- }
- case _ => JNull
- }
-}
-
-object TreeNode {
- def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
- val jsonAST = parse(json)
- assert(jsonAST.isInstanceOf[JArray])
- reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
- }
-
- private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
- assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
- val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
-
- def parseNextNode(): TreeNode[_] = {
- val nextNode = jsonNodes.pop()
-
- val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
- if (cls == classOf[Literal]) {
- Literal.fromJSON(nextNode)
- } else if (cls.getName.endsWith("$")) {
- cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
- } else {
- val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
-
- val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
- val fields = getConstructorParameters(cls)
-
- val parameters: Array[AnyRef] = fields.map {
- case (fieldName, fieldType) =>
- parseFromJson(nextNode \ fieldName, fieldType, children, sc)
- }.toArray
-
- val maybeCtor = cls.getConstructors.find { p =>
- val expectedTypes = p.getParameterTypes
- expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
- case (cls, tpe) => cls == getClassFromType(tpe)
- }
- }
- if (maybeCtor.isEmpty) {
- sys.error(s"No valid constructor for ${cls.getName}")
- } else {
- try {
- maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
- } catch {
- case e: java.lang.IllegalArgumentException =>
- throw new RuntimeException(
- s"""
- |Failed to construct tree node: ${cls.getName}
- |ctor: ${maybeCtor.get}
- |types: ${parameters.map(_.getClass).mkString(", ")}
- |args: ${parameters.mkString(", ")}
- """.stripMargin, e)
- }
- }
- }
- }
-
- parseNextNode()
- }
-
- import universe._
-
- private def parseFromJson(
- value: JValue,
- expectedType: Type,
- children: Seq[TreeNode[_]],
- sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
- if (value == JNull) return null
-
- expectedType match {
- case t if t <:< definitions.BooleanTpe =>
- value.asInstanceOf[JBool].value: java.lang.Boolean
- case t if t <:< definitions.ByteTpe =>
- value.asInstanceOf[JInt].num.toByte: java.lang.Byte
- case t if t <:< definitions.ShortTpe =>
- value.asInstanceOf[JInt].num.toShort: java.lang.Short
- case t if t <:< definitions.IntTpe =>
- value.asInstanceOf[JInt].num.toInt: java.lang.Integer
- case t if t <:< definitions.LongTpe =>
- value.asInstanceOf[JInt].num.toLong: java.lang.Long
- case t if t <:< definitions.FloatTpe =>
- value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
- case t if t <:< definitions.DoubleTpe =>
- value.asInstanceOf[JDouble].num: java.lang.Double
-
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- value.asInstanceOf[JBool].value: java.lang.Boolean
- case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
- case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
- case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
- case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
- case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
- case t if t <:< localTypeOf[StorageLevel] =>
- val JBool(useDisk) = value \ "useDisk"
- val JBool(useMemory) = value \ "useMemory"
- val JBool(useOffHeap) = value \ "useOffHeap"
- val JBool(deserialized) = value \ "deserialized"
- val JInt(replication) = value \ "replication"
- StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
- case t if t <:< localTypeOf[TreeNode[_]] => value match {
- case JInt(i) => children(i.toInt)
- case arr: JArray => reconstruct(arr, sc)
- case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+ case p: Product if shouldConvertToJson(p) =>
+ try {
+ val fieldNames = getConstructorParameterNames(p.getClass)
+ val fieldValues = p.productIterator.toSeq
+ assert(fieldNames.length == fieldValues.length)
+ ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+ case (name, value) => name -> parseToJson(value)
+ }.toList
+ } catch {
+ case _: RuntimeException => null
}
- case t if t <:< localTypeOf[Option[_]] =>
- if (value == JNothing) {
- None
- } else {
- val TypeRef(_, _, Seq(optType)) = t
- Option(parseFromJson(value, optType, children, sc))
- }
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val JArray(elements) = value
- elements.map(parseFromJson(_, elementType, children, sc)).toSeq
- case t if t <:< localTypeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val JObject(fields) = value
- fields.map {
- case (name, value) => name -> parseFromJson(value, valueType, children, sc)
- }.toMap
- case t if t <:< localTypeOf[RDD[_]] =>
- new EmptyRDD[Any](sc)
- case _ if isScalaObject(value) =>
- val JString(clsName) = value \ "object"
- val cls = Utils.classForName(clsName)
- cls.getField("MODULE$").get(cls)
- case t if t <:< localTypeOf[Product] =>
- val fields = getConstructorParameters(t)
- val clsName = getClassNameFromType(t)
- parseToProduct(clsName, fields, value, children, sc)
- // There maybe some cases that the parameter type signature is not Product but the value is,
- // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here.
- case _ if isScalaProduct(value) =>
- val JString(clsName) = value \ "product-class"
- val fields = getConstructorParameters(Utils.classForName(clsName))
- parseToProduct(clsName, fields, value, children, sc)
- case _ => sys.error(s"Do not support type $expectedType with json $value.")
- }
- }
-
- private def parseToProduct(
- clsName: String,
- fields: Seq[(String, Type)],
- value: JValue,
- children: Seq[TreeNode[_]],
- sc: SparkContext): AnyRef = {
- val parameters: Array[AnyRef] = fields.map {
- case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
- }.toArray
- val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
- ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
- }
-
- private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
- case JString(str) if str.endsWith("$") => true
- case _ => false
+ case _ => JNull
}
- private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
- case _: JString => true
+ private def shouldConvertToJson(product: Product): Boolean = product match {
+ case exprId: ExprId => true
+ case field: StructField => true
+ case id: TableIdentifier => true
+ case join: JoinType => true
+ case id: FunctionIdentifier => true
+ case spec: BucketSpec => true
+ case catalog: CatalogTable => true
+ case boundary: FrameBoundary => true
+ case frame: WindowFrame => true
+ case partition: Partitioning => true
+ case resource: FunctionResource => true
+ case broadcast: BroadcastMode => true
+ case table: CatalogTableType => true
+ case storage: CatalogStorageFormat => true
case _ => false
}
}