aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-01-28 22:42:43 -0800
committerDavies Liu <davies.liu@gmail.com>2016-01-28 22:43:03 -0800
commit721ced28b522cc00b45ca7fa32a99e80ad3de2f7 (patch)
treedeb85e458edc143364c8eeb834a26fe55272c192 /sql
parent66449b8dcdbc3dca126c34b42c4d0419c7648696 (diff)
downloadspark-721ced28b522cc00b45ca7fa32a99e80ad3de2f7.tar.gz
spark-721ced28b522cc00b45ca7fa32a99e80ad3de2f7.tar.bz2
spark-721ced28b522cc00b45ca7fa32a99e80ad3de2f7.zip
[SPARK-13067] [SQL] workaround for a weird scala reflection problem
A simple workaround to avoid getting parameter types when convert a logical plan to json. Author: Wenchen Fan <wenchen@databricks.com> Closes #10970 from cloud-fan/reflection.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala4
2 files changed, 23 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 643228d0eb..e5811efb43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -601,6 +601,20 @@ object ScalaReflection extends ScalaReflection {
getConstructorParameters(t)
}
+ /**
+ * Returns the parameter names for the primary constructor of this class.
+ *
+ * Logically we should call `getConstructorParameters` and throw away the parameter types to get
+ * parameter names, however there are some weird scala reflection problems and this method is a
+ * workaround to avoid getting parameter types.
+ */
+ def getConstructorParameterNames(cls: Class[_]): Seq[String] = {
+ val m = runtimeMirror(cls.getClassLoader)
+ val classSymbol = m.staticClass(cls.getName)
+ val t = classSymbol.selfType
+ constructParams(t).map(_.name.toString)
+ }
+
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}
@@ -745,6 +759,12 @@ trait ScalaReflection {
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = tpe
+ constructParams(tpe).map { p =>
+ p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ }
+ }
+
+ protected def constructParams(tpe: Type): Seq[Symbol] = {
val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramss
@@ -758,9 +778,6 @@ trait ScalaReflection {
primaryConstructorSymbol.get.asMethod.paramss
}
}
-
- params.flatten.map { p =>
- p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- }
+ params.flatten
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 57e1a3c9eb..2df0683f9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
protected def jsonFields: List[JField] = {
- val fieldNames = getConstructorParameters(getClass).map(_._1)
+ val fieldNames = getConstructorParameterNames(getClass)
val fieldValues = productIterator.toSeq ++ otherCopyArgs
assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
@@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
// returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
case p: Product => try {
- val fieldNames = getConstructorParameters(p.getClass).map(_._1)
+ val fieldNames = getConstructorParameterNames(p.getClass)
val fieldValues = p.productIterator.toSeq
assert(fieldNames.length == fieldValues.length)
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {