author    Wenchen Fan <wenchen@databricks.com>       2015-12-21 12:47:07 -0800
committer Michael Armbrust <michael@databricks.com>  2015-12-21 12:47:07 -0800
commit    7634fe9511e1a8fb94979624b1b617b495b48ad3 (patch)
tree      df63fbdc4c50a5675540f33e4f1bdc3b00e4d629 /sql
parent    474eb21a30f7ee898f76a625a5470c8245af1d22 (diff)
download  spark-7634fe9511e1a8fb94979624b1b617b495b48ad3.tar.gz
          spark-7634fe9511e1a8fb94979624b1b617b495b48ad3.tar.bz2
          spark-7634fe9511e1a8fb94979624b1b617b495b48ad3.zip
[SPARK-12321][SQL] JSON format for TreeNode (use reflection)
An alternative solution to https://github.com/apache/spark/pull/10295: instead of implementing a JSON format for every logical/physical plan and expression, use reflection to implement it once in `TreeNode`. A pre-order traversal flattens the plan tree into a list of plan nodes, and an extra `num-children` field is added to each node so that the tree can be reconstructed from the list.

Example JSON for a logical plan tree:

```
[ {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.Sort",
  "num-children" : 1,
  "order" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.SortOrder",
    "num-children" : 1,
    "child" : 0,
    "direction" : "Ascending"
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "i",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : { "id" : 10, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  } ] ],
  "global" : false,
  "child" : 0
}, {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.Project",
  "num-children" : 1,
  "projectList" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.Alias",
    "num-children" : 1,
    "child" : 0,
    "name" : "i",
    "exprId" : { "id" : 10, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Add",
    "num-children" : 2,
    "left" : 0,
    "right" : 1
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Literal",
    "num-children" : 0,
    "value" : "1",
    "dataType" : "integer"
  } ], [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.Alias",
    "num-children" : 1,
    "child" : 0,
    "name" : "j",
    "exprId" : { "id" : 11, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Multiply",
    "num-children" : 2,
    "left" : 0,
    "right" : 1
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  }, {
    "class" : "org.apache.spark.sql.catalyst.expressions.Literal",
    "num-children" : 0,
    "value" : "2",
    "dataType" : "integer"
  } ] ],
  "child" : 0
}, {
  "class" : "org.apache.spark.sql.catalyst.plans.logical.LocalRelation",
  "num-children" : 0,
  "output" : [ [ {
    "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
    "num-children" : 0,
    "name" : "a",
    "dataType" : "integer",
    "nullable" : true,
    "metadata" : { },
    "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" },
    "qualifiers" : [ ]
  } ] ],
  "data" : [ ]
} ]
```

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10311 from cloud-fan/toJson-reflection.
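For orientation, here is a minimal, self-contained sketch (plain Scala, not the actual `TreeNode` code in the diff below) of why a pre-order node list plus a `num-children` count per node is enough to rebuild the tree:

```scala
// Pre-order traversal emits each node together with its child count;
// reconstruction reads the list back in the same order.
case class Node(name: String, children: Seq[Node])

def flatten(n: Node): Seq[(String, Int)] =
  (n.name, n.children.length) +: n.children.flatMap(flatten)

def rebuild(flat: Seq[(String, Int)]): Node = {
  val it = flat.iterator
  def next(): Node = {
    val (name, numChildren) = it.next()
    Node(name, (1 to numChildren).map(_ => next()))
  }
  next()
}

// Sort -> Project -> LocalRelation, as in the example above, round-trips intact.
val plan = Node("Sort", Seq(Node("Project", Seq(Node("LocalRelation", Nil)))))
assert(rebuild(flatten(plan)) == plan)
```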
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                  | 114
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala             |  41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala     |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala                  |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala                   | 258
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala                            |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala                         |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala  |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                     | 102
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                          |   5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala                     |   2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala           |   2
13 files changed, 472 insertions, 75 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c1b1d5cd2d..cc9e6af181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -68,7 +68,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = tpe
arrayClassFor(elementType)
case other =>
- val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+ val clazz = getClassFromType(tpe)
ObjectType(clazz)
}
}
@@ -321,29 +321,11 @@ object ScalaReflection extends ScalaReflection {
keyData :: valueData :: Nil)
case t if t <:< localTypeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val constructorSymbol = t.member(nme.CONSTRUCTOR)
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramss
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] =
- constructorSymbol.asTerm.alternatives.find(s =>
- s.isMethod && s.asMethod.isPrimaryConstructor)
+ val params = getConstructorParameters(t)
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramss
- }
- }
+ val cls = getClassFromType(tpe)
- val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
-
- val arguments = params.head.zipWithIndex.map { case (p, i) =>
- val fieldName = p.name.toString
- val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val dataType = schemaFor(fieldType).dataType
val clsName = getClassNameFromType(fieldType)
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
@@ -477,27 +459,9 @@ object ScalaReflection extends ScalaReflection {
}
case t if t <:< localTypeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val constructorSymbol = t.member(nme.CONSTRUCTOR)
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramss
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] =
- constructorSymbol.asTerm.alternatives.find(s =>
- s.isMethod && s.asMethod.isPrimaryConstructor)
-
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramss
- }
- }
+ val params = getConstructorParameters(t)
- CreateNamedStruct(params.head.flatMap { p =>
- val fieldName = p.name.toString
- val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
@@ -595,6 +559,21 @@ object ScalaReflection extends ScalaReflection {
}
}
}
+
+ /**
+ * Returns the parameter names and types for the primary constructor of this class.
+ *
+ * Note that it only works for scala classes with primary constructor, and currently doesn't
+ * support inner class.
+ */
+ def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = {
+ val m = runtimeMirror(cls.getClassLoader)
+ val classSymbol = m.staticClass(cls.getName)
+ val t = classSymbol.selfType
+ getConstructorParameters(t)
+ }
+
+ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}
/**
@@ -668,26 +647,11 @@ trait ScalaReflection {
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< localTypeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val constructorSymbol = t.member(nme.CONSTRUCTOR)
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramss
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
- s => s.isMethod && s.asMethod.isPrimaryConstructor)
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramss
- }
- }
+ val params = getConstructorParameters(t)
Schema(StructType(
- params.head.map { p =>
- val Schema(dataType, nullable) =
- schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
- StructField(p.name.toString, dataType, nullable)
+ params.map { case (fieldName, fieldType) =>
+ val Schema(dataType, nullable) = schemaFor(fieldType)
+ StructField(fieldName, dataType, nullable)
}), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
@@ -740,4 +704,32 @@ trait ScalaReflection {
assert(methods.length == 1)
methods.head.getParameterTypes
}
+
+ /**
+ * Returns the parameter names and types for the primary constructor of this type.
+ *
+ * Note that it only works for scala classes with primary constructor, and currently doesn't
+ * support inner class.
+ */
+ def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
+ val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = tpe
+ val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramss
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
+ s => s.isMethod && s.asMethod.isPrimaryConstructor)
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramss
+ }
+ }
+
+ params.flatten.map { p =>
+ p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ }
+ }
}
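As a usage sketch of the helper added above (the case class `Point` is hypothetical and used only for illustration):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

// Hypothetical top-level case class, used only to illustrate the helper's output.
case class Point(x: Int, label: String)

// Expected to yield the primary-constructor parameters in declaration order,
// i.e. ("x", Int) and ("label", String) as (name, scala.reflect Type) pairs.
val params = ScalaReflection.getConstructorParameters(classOf[Point])
println(params.map { case (name, tpe) => s"$name: $tpe" }.mkString(", "))
```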
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index b6d2ddc5b1..b616d6953b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 68ec688c99..e3573b4947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.json4s.JsonAST._
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -55,6 +56,34 @@ object Literal {
*/
def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
+ def fromJSON(json: JValue): Literal = {
+ val dataType = DataType.parseDataType(json \ "dataType")
+ json \ "value" match {
+ case JNull => Literal.create(null, dataType)
+ case JString(str) =>
+ val value = dataType match {
+ case BooleanType => str.toBoolean
+ case ByteType => str.toByte
+ case ShortType => str.toShort
+ case IntegerType => str.toInt
+ case LongType => str.toLong
+ case FloatType => str.toFloat
+ case DoubleType => str.toDouble
+ case StringType => UTF8String.fromString(str)
+ case DateType => java.sql.Date.valueOf(str)
+ case TimestampType => java.sql.Timestamp.valueOf(str)
+ case CalendarIntervalType => CalendarInterval.fromString(str)
+ case t: DecimalType =>
+ val d = Decimal(str)
+ assert(d.changePrecision(t.precision, t.scale))
+ d
+ case _ => null
+ }
+ Literal.create(value, dataType)
+ case other => sys.error(s"$other is not a valid Literal json value")
+ }
+ }
+
def create(v: Any, dataType: DataType): Literal = {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
@@ -123,6 +152,18 @@ case class Literal protected (value: Any, dataType: DataType)
case _ => false
}
+ override protected def jsonFields: List[JField] = {
+ // Turn all kinds of literal values into strings in the json field, as the type info is hard to
+ // retain in json format, e.g. {"a": 123} can be an int, a double, a decimal, etc.
+ val jsonValue = (value, dataType) match {
+ case (null, _) => JNull
+ case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString)
+ case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString)
+ case (other, _) => JString(other.toString)
+ }
+ ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil
+ }
+
override def eval(input: InternalRow): Any = value
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
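A minimal round-trip sketch for the new `Literal` JSON hooks (only the two fields that `fromJSON` actually reads are shown; `TreeNode.jsonValue` wraps them with `class` and `num-children`):

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.IntegerType

// Literal values are serialized as strings next to their dataType, so the type
// information survives JSON; fromJSON converts the string back using that type.
val json = parse("""{ "value" : "1", "dataType" : "integer" }""")
assert(Literal.fromJSON(json) == Literal.create(1, IntegerType))
```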
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 26b6aca799..eefd9c7482 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,6 +262,10 @@ case class AttributeReference(
}
}
+ override protected final def otherCopyArgs: Seq[AnyRef] = {
+ exprId :: qualifiers :: Nil
+ }
+
override def toString: String = s"$name#${exprId.id}$typeSuffix"
// Since the expression id is not in the first constructor it is missing from the default
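A quick illustration of why `otherCopyArgs` is needed here: `TreeNode.jsonFields` (see the TreeNode diff below) pairs constructor parameter names with `productIterator.toSeq ++ otherCopyArgs`, and for a curried constructor like `AttributeReference`'s, the second parameter list never shows up in `productIterator`. The class below is hypothetical, not Spark code:

```scala
// For a case class with a curried constructor, productIterator only covers the
// first parameter list; the extra arguments must be supplied separately, which
// is exactly what otherCopyArgs does for exprId and qualifiers above.
case class Attr(name: String, nullable: Boolean)(val id: Long)

val a = Attr("i", nullable = true)(10L)
assert(a.productIterator.toSeq == Seq("i", true))                 // id is missing
assert(a.productIterator.toSeq ++ Seq(a.id) == Seq("i", true, 10L))
```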
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b9db7838db..d2626440b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -88,6 +88,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
+ case null => null
}
val newArgs = productIterator.map(recursiveTransform).toArray
@@ -120,6 +121,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
+ case null => null
}
val newArgs = productIterator.map(recursiveTransform).toArray
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d838d845d2..c97dc2d8be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,9 +17,25 @@
package org.apache.spark.sql.catalyst.trees
+import java.util.UUID
import scala.collection.Map
-
+import scala.collection.mutable.Stack
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.ScalaReflection._
+import org.apache.spark.sql.catalyst.{TableIdentifier, ScalaReflectionLock}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{StructType, DataType}
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
@@ -463,4 +479,244 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
s"$nodeName(${args.mkString(",")})"
}
+
+ def toJSON: String = compact(render(jsonValue))
+
+ def prettyJson: String = pretty(render(jsonValue))
+
+ private def jsonValue: JValue = {
+ val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue]
+
+ def collectJsonValue(tn: BaseType): Unit = {
+ val jsonFields = ("class" -> JString(tn.getClass.getName)) ::
+ ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields
+ jsonValues += JObject(jsonFields)
+ tn.children.foreach(collectJsonValue)
+ }
+
+ collectJsonValue(this)
+ jsonValues
+ }
+
+ protected def jsonFields: List[JField] = {
+ val fieldNames = getConstructorParameters(getClass).map(_._1)
+ val fieldValues = productIterator.toSeq ++ otherCopyArgs
+ assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+ fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
+
+ fieldNames.zip(fieldValues).map {
+ // If the field value is a child, encode it as an int representing the index of
+ // this child among all children.
+ case (name, value: TreeNode[_]) if containsChild(value) =>
+ name -> JInt(children.indexOf(value))
+ case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+ name -> JArray(
+ value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
+ )
+ case (name, value) => name -> parseToJson(value)
+ }.toList
+ }
+
+ private def parseToJson(obj: Any): JValue = obj match {
+ case b: Boolean => JBool(b)
+ case b: Byte => JInt(b.toInt)
+ case s: Short => JInt(s.toInt)
+ case i: Int => JInt(i)
+ case l: Long => JInt(l)
+ case f: Float => JDouble(f)
+ case d: Double => JDouble(d)
+ case b: BigInt => JInt(b)
+ case null => JNull
+ case s: String => JString(s)
+ case u: UUID => JString(u.toString)
+ case dt: DataType => dt.jsonValue
+ case m: Metadata => m.jsonValue
+ case s: StorageLevel =>
+ ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~
+ ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
+ case n: TreeNode[_] => n.jsonValue
+ case o: Option[_] => o.map(parseToJson)
+ case t: Seq[_] => JArray(t.map(parseToJson).toList)
+ case m: Map[_, _] =>
+ val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
+ JObject(fields)
+ case r: RDD[_] => JNothing
+ // if it's a scala object, we can simply keep the full class path.
+ // TODO: currently we assume a class whose name ends with "$" is a scala object; there is
+ // probably a better way to check this.
+ case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
+ // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
+ case p: Product => try {
+ val fieldNames = getConstructorParameters(p.getClass).map(_._1)
+ val fieldValues = p.productIterator.toSeq
+ assert(fieldNames.length == fieldValues.length)
+ ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
+ case (name, value) => name -> parseToJson(value)
+ }.toList
+ } catch {
+ case _: RuntimeException => null
+ }
+ case _ => JNull
+ }
+}
+
+object TreeNode {
+ def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = {
+ val jsonAST = parse(json)
+ assert(jsonAST.isInstanceOf[JArray])
+ reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
+ }
+
+ private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = {
+ assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
+ val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
+
+ def parseNextNode(): TreeNode[_] = {
+ val nextNode = jsonNodes.pop()
+
+ val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s)
+ if (cls == classOf[Literal]) {
+ Literal.fromJSON(nextNode)
+ } else if (cls.getName.endsWith("$")) {
+ cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
+ } else {
+ val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt
+
+ val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode())
+ val fields = getConstructorParameters(cls)
+
+ val parameters: Array[AnyRef] = fields.map {
+ case (fieldName, fieldType) =>
+ parseFromJson(nextNode \ fieldName, fieldType, children, sc)
+ }.toArray
+
+ val maybeCtor = cls.getConstructors.find { p =>
+ val expectedTypes = p.getParameterTypes
+ expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall {
+ case (cls, tpe) => cls == getClassFromType(tpe)
+ }
+ }
+ if (maybeCtor.isEmpty) {
+ sys.error(s"No valid constructor for ${cls.getName}")
+ } else {
+ try {
+ maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
+ } catch {
+ case e: java.lang.IllegalArgumentException =>
+ throw new RuntimeException(
+ s"""
+ |Failed to construct tree node: ${cls.getName}
+ |ctor: ${maybeCtor.get}
+ |types: ${parameters.map(_.getClass).mkString(", ")}
+ |args: ${parameters.mkString(", ")}
+ """.stripMargin, e)
+ }
+ }
+ }
+ }
+
+ parseNextNode()
+ }
+
+ import universe._
+
+ private def parseFromJson(
+ value: JValue,
+ expectedType: Type,
+ children: Seq[TreeNode[_]],
+ sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
+ if (value == JNull) return null
+
+ expectedType match {
+ case t if t <:< definitions.BooleanTpe =>
+ value.asInstanceOf[JBool].value: java.lang.Boolean
+ case t if t <:< definitions.ByteTpe =>
+ value.asInstanceOf[JInt].num.toByte: java.lang.Byte
+ case t if t <:< definitions.ShortTpe =>
+ value.asInstanceOf[JInt].num.toShort: java.lang.Short
+ case t if t <:< definitions.IntTpe =>
+ value.asInstanceOf[JInt].num.toInt: java.lang.Integer
+ case t if t <:< definitions.LongTpe =>
+ value.asInstanceOf[JInt].num.toLong: java.lang.Long
+ case t if t <:< definitions.FloatTpe =>
+ value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
+ case t if t <:< definitions.DoubleTpe =>
+ value.asInstanceOf[JDouble].num: java.lang.Double
+
+ case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
+ case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
+ case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
+ case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
+ case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
+ case t if t <:< localTypeOf[StorageLevel] =>
+ val JBool(useDisk) = value \ "useDisk"
+ val JBool(useMemory) = value \ "useMemory"
+ val JBool(useOffHeap) = value \ "useOffHeap"
+ val JBool(deserialized) = value \ "deserialized"
+ val JInt(replication) = value \ "replication"
+ StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
+ case t if t <:< localTypeOf[TreeNode[_]] => value match {
+ case JInt(i) => children(i.toInt)
+ case arr: JArray => reconstruct(arr, sc)
+ case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+ }
+ case t if t <:< localTypeOf[Option[_]] =>
+ if (value == JNothing) {
+ None
+ } else {
+ val TypeRef(_, _, Seq(optType)) = t
+ Option(parseFromJson(value, optType, children, sc))
+ }
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val JArray(elements) = value
+ elements.map(parseFromJson(_, elementType, children, sc)).toSeq
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val JObject(fields) = value
+ fields.map {
+ case (name, value) => name -> parseFromJson(value, valueType, children, sc)
+ }.toMap
+ case t if t <:< localTypeOf[RDD[_]] =>
+ new EmptyRDD[Any](sc)
+ case _ if isScalaObject(value) =>
+ val JString(clsName) = value \ "object"
+ val cls = Utils.classForName(clsName)
+ cls.getField("MODULE$").get(cls)
+ case t if t <:< localTypeOf[Product] =>
+ val fields = getConstructorParameters(t)
+ val clsName = getClassNameFromType(t)
+ parseToProduct(clsName, fields, value, children, sc)
+ // There may be cases where the parameter type signature is not Product but the value is,
+ // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`; handle them here.
+ case _ if isScalaProduct(value) =>
+ val JString(clsName) = value \ "product-class"
+ val fields = getConstructorParameters(Utils.classForName(clsName))
+ parseToProduct(clsName, fields, value, children, sc)
+ case _ => sys.error(s"Do not support type $expectedType with json $value.")
+ }
+ }
+
+ private def parseToProduct(
+ clsName: String,
+ fields: Seq[(String, Type)],
+ value: JValue,
+ children: Seq[TreeNode[_]],
+ sc: SparkContext): AnyRef = {
+ val parameters: Array[AnyRef] = fields.map {
+ case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc)
+ }.toArray
+ val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
+ ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
+ }
+
+ private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match {
+ case JString(str) if str.endsWith("$") => true
+ case _ => false
+ }
+
+ private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match {
+ case _: JString => true
+ case _ => false
+ }
}
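As a rough usage sketch of the new API (the plan construction is illustrative; `fromJSON` only uses the `SparkContext` to stub out RDD-backed fields with empty RDDs):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.IntegerType

def roundTrip(sc: SparkContext): LogicalPlan = {
  val a = AttributeReference("a", IntegerType)()
  val plan: LogicalPlan = Project(Seq(Alias(a, "b")()), LocalRelation(a))

  val json = plan.toJSON        // compact JSON: a pre-order array of node objects
  println(plan.prettyJson)      // multi-line form, like the example in the commit message

  // Rebuild the tree from the JSON string.
  TreeNode.fromJSON[LogicalPlan](json, sc)
}
```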
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index b0c43c4100..f8d71c5f02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -107,8 +107,8 @@ object DataType {
def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
private val nonDecimalNameToType = {
- Seq(NullType, DateType, TimestampType, BinaryType,
- IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+ Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
+ DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType)
.map(t => t.typeName -> t).toMap
}
@@ -130,7 +130,7 @@ object DataType {
}
// NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
- private def parseDataType(json: JValue): DataType = json match {
+ private[sql] def parseDataType(json: JValue): DataType = json match {
case JString(name) =>
nameToType(name)
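With `CalendarIntervalType` registered in the name map, interval types can also round-trip through the name-based JSON encoding. A small sketch, assuming the default type name derived from the class name is `calendarinterval`:

```scala
import org.apache.spark.sql.types.{CalendarIntervalType, DataType}

// fromJson goes through the same parseDataType shown above; the type name
// "calendarinterval" is an assumption based on the default typeName derivation.
assert(DataType.fromJson("\"calendarinterval\"") == CalendarIntervalType)
```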
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b8a4302588..ea5a9afe03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -74,9 +74,7 @@ private[sql] case class LogicalRDD(
override def children: Seq[LogicalPlan] = Nil
- override protected final def otherCopyArgs: Seq[AnyRef] = {
- sqlContext :: Nil
- }
+ override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil
override def newInstance(): LogicalRDD.this.type =
LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 3c5a8cb2aa..4afa5f8ec1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -61,9 +61,9 @@ private[sql] case class InMemoryRelation(
storageLevel: StorageLevel,
@transient child: SparkPlan,
tableName: Option[String])(
- @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null,
- @transient private var _statistics: Statistics = null,
- private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
+ @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
+ @transient private[sql] var _statistics: Statistics = null,
+ private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
extends LogicalPlan with MultiInstanceRelation {
private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index bc22fb8b7b..9246f55020 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
-import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.Queryable
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
abstract class QueryTest extends PlanTest {
@@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest {
|""".stripMargin)
}
+ checkJsonFormat(analyzedDF)
+
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
@@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest {
s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
planWithCaching)
}
+
+ private def checkJsonFormat(df: DataFrame): Unit = {
+ val logicalPlan = df.queryExecution.analyzed
+ // bypass some cases that we can't handle currently.
+ logicalPlan.transform {
+ case _: MapPartitions[_, _] => return
+ case _: MapGroups[_, _, _] => return
+ case _: AppendColumns[_, _] => return
+ case _: CoGroup[_, _, _, _] => return
+ case _: LogicalRelation => return
+ }.transformAllExpressions {
+ case a: ImperativeAggregate => return
+ }
+
+ val jsonString = try {
+ logicalPlan.toJSON
+ } catch {
+ case e =>
+ fail(
+ s"""
+ |Failed to parse logical plan to JSON:
+ |${logicalPlan.treeString}
+ """.stripMargin, e)
+ }
+
+ // bypass hive tests before we fix all corner cases in hive module.
+ if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
+
+ // Scala functions are not serializable to JSON, so replace them with null so that we can compare
+ // the plans later.
+ val normalized1 = logicalPlan.transformAllExpressions {
+ case udf: ScalaUDF => udf.copy(function = null)
+ case gen: UserDefinedGenerator => gen.copy(function = null)
+ }
+
+ // RDDs/data are not serializable to JSON, so we collect the LogicalPlans that contain
+ // this non-serializable state, and use the originals to replace the null placeholders
+ // in the logical plans parsed from JSON.
+ var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
+ var localRelations = logicalPlan.collect { case l: LocalRelation => l }
+ var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+
+ val jsonBackPlan = try {
+ TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+ } catch {
+ case e =>
+ fail(
+ s"""
+ |Failed to rebuild the logical plan from JSON:
+ |${logicalPlan.treeString}
+ |
+ |${logicalPlan.prettyJson}
+ """.stripMargin, e)
+ }
+
+ val normalized2 = jsonBackPlan transformDown {
+ case l: LogicalRDD =>
+ val origin = logicalRDDs.head
+ logicalRDDs = logicalRDDs.drop(1)
+ LogicalRDD(l.output, origin.rdd)(sqlContext)
+ case l: LocalRelation =>
+ val origin = localRelations.head
+ localRelations = localRelations.drop(1)
+ l.copy(data = origin.data)
+ case l: InMemoryRelation =>
+ val origin = inMemoryRelations.head
+ inMemoryRelations = inMemoryRelations.drop(1)
+ InMemoryRelation(
+ l.output,
+ l.useCompression,
+ l.batchSize,
+ l.storageLevel,
+ origin.child,
+ l.tableName)(
+ origin.cachedColumnBuffers,
+ l._statistics,
+ origin._batchStats)
+ }
+
+ assert(logicalRDDs.isEmpty)
+ assert(localRelations.isEmpty)
+ assert(inMemoryRelations.isEmpty)
+
+ if (normalized1 != normalized2) {
+ fail(
+ s"""
+ |== FAIL: the logical plan parsed from json does not match the original one ===
+ |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
}
object QueryTest {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index f602f2fb89..2a1117318a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
private[spark] override def asNullable: MyDenseVectorUDT = this
+
+ override def equals(other: Any): Boolean = other match {
+ case _: MyDenseVectorUDT => true
+ case _ => false
+ }
}
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 08b291e088..f099e146d1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation
Objects.hashCode(databaseName, tableName, alias, output)
}
+ override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil
+
@transient val hiveQlTable: Table = {
// We start by constructing an API table as Hive performs several important transformations
// internally when converting an API table to a QL table.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index b30117f0de..d9b9ba4bfd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -58,7 +58,7 @@ case class ScriptTransformation(
ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext)
extends UnaryNode {
- override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+ override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)