author    Reynold Xin <rxin@databricks.com>    2015-06-13 18:22:17 -0700
committer Michael Armbrust <michael@databricks.com>    2015-06-13 18:22:17 -0700
commit    2d71ba4c8a629deab672869ac8e8b6a4b3aec479 (patch)
tree      419aae1b9e462a9ea43eb6fee0ce826f8b1baec3 /sql
parent    a138953391975886c88bfe81d4ce6b6dd189cd32 (diff)
[SPARK-8349] [SQL] Use expression constructors (rather than apply) in FunctionRegistry
Author: Reynold Xin <rxin@databricks.com>

Closes #6806 from rxin/gs and squashes the following commits:

ed1aebb [Reynold Xin] Fixed style.
c7fc3e6 [Reynold Xin] [SPARK-8349][SQL] Use expression constructors (rather than apply) in FunctionRegistry
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala    | 18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala           | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala               | 10
5 files changed, 22 insertions(+), 32 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 45bcbf73fa..04e306da23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -158,27 +158,23 @@ object FunctionRegistry {
/** See usage above. */
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
- // Use the companion class to find apply methods.
- val objectClass = Class.forName(tag.runtimeClass.getName + "$")
- val companionObj = objectClass.getDeclaredField("MODULE$").get(null)
-
- // See if we can find an apply that accepts Seq[Expression]
- val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption
+ // See if we can find a constructor that accepts Seq[Expression]
+ val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
val builder = (expressions: Seq[Expression]) => {
- if (varargApply.isDefined) {
+ if (varargCtor.isDefined) {
// If there is a constructor that accepts Seq[Expression], use that one.
- varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression]
+ varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
} else {
- // Otherwise, find an apply method that matches the number of arguments, and use that.
+ // Otherwise, find a constructor that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
- val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match {
+ val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match {
case Success(e) =>
e
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
- f.invoke(companionObj, expressions : _*).asInstanceOf[Expression]
+ f.newInstance(expressions : _*).asInstanceOf[Expression]
}
}
(name, builder)
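For illustration, here is a minimal standalone sketch of the constructor-reflection pattern this hunk introduces. Expr, Lit, Concat, Add, and CtorDemo are toy stand-ins, not Spark classes: the builder first looks for a constructor taking a single Seq (vararg-style expressions), then falls back to a constructor whose arity matches the number of arguments.

import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

trait Expr
case class Lit(v: Int) extends Expr
case class Concat(children: Seq[Expr]) extends Expr   // vararg-style: one Seq constructor
case class Add(left: Expr, right: Expr) extends Expr  // fixed arity: two Expr parameters

object CtorDemo {
  type Builder = Seq[Expr] => Expr

  def expression[T <: Expr](name: String)(implicit tag: ClassTag[T]): (String, Builder) = {
    // Prefer a constructor that accepts a single Seq, as the patch does.
    val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
    val builder = (exprs: Seq[Expr]) => varargCtor match {
      case Some(c) => c.newInstance(exprs).asInstanceOf[Expr]
      case None =>
        // Fall back to a constructor taking one Expr per argument.
        val params = Seq.fill(exprs.size)(classOf[Expr])
        Try(tag.runtimeClass.getDeclaredConstructor(params: _*)) match {
          case Success(c) => c.newInstance(exprs: _*).asInstanceOf[Expr]
          case Failure(_) =>
            throw new IllegalArgumentException(s"Invalid number of arguments for function $name")
        }
    }
    (name, builder)
  }

  def main(args: Array[String]): Unit = {
    val (_, concat) = expression[Concat]("concat")
    val (_, add) = expression[Add]("add")
    println(concat(Seq(Lit(1), Lit(2), Lit(3))))  // Concat(List(Lit(1), Lit(2), Lit(3)))
    println(add(Seq(Lit(1), Lit(2))))             // Add(Lit(1),Lit(2))
  }
}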
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 61de34bfa4..7427ca76b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.types._
/**
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)", the concrete implementation must be a case class whose constructor
- * arguments are all Expressions types. In addition, if it needs to support more than one
- * constructor, define those constructors explicitly as apply methods in the companion object.
+ * arguments are all Expression types.
*
* See [[Substring]] for an example.
*/
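The revised contract can be read off a small example. Pad below is hypothetical, not a Spark expression: every constructor parameter is an expression, and the extra two-argument form is an auxiliary constructor, so plain constructor reflection sees it. A companion-object apply would live on the separate Pad$ class and be invisible to getDeclaredConstructors.

trait Expr
case class Lit(v: Int) extends Expr
case class Pad(str: Expr, len: Expr, fill: Expr) extends Expr {
  // Extra arity as an auxiliary constructor, per the new contract.
  def this(str: Expr, len: Expr) = this(str, len, Lit(' '.toInt))
}

object ContractDemo extends App {
  // Both arities are visible via constructor reflection: List(3, 2) (order unspecified).
  println(classOf[Pad].getDeclaredConstructors.map(_.getParameterCount).toList)
}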
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index 7e8033307e..cc34467391 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -49,12 +49,10 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class Rand(seed: Long) extends RDG(seed) {
override def eval(input: InternalRow): Double = rng.nextDouble()
-}
-object Rand {
- def apply(): Rand = apply(Utils.random.nextLong())
+ def this() = this(Utils.random.nextLong())
- def apply(seed: Expression): Rand = apply(seed match {
+ def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
@@ -63,12 +61,10 @@ object Rand {
/** Generate a random column with i.i.d. gaussian random distribution. */
case class Randn(seed: Long) extends RDG(seed) {
override def eval(input: InternalRow): Double = rng.nextGaussian()
-}
-object Randn {
- def apply(): Randn = apply(Utils.random.nextLong())
+ def this() = this(Utils.random.nextLong())
- def apply(seed: Expression): Randn = apply(seed match {
+ def this(seed: Expression) = this(seed match {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
})
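A toy version of the pattern above may help: the primary constructor takes the resolved Long seed, while the auxiliary constructors either draw a random seed or extract one from a literal argument. Rnd, Expr, and IntLit are illustrative stand-ins for Spark's Rand, Expression, and IntegerLiteral.

import scala.util.Random

trait Expr
case class IntLit(v: Int) extends Expr

case class Rnd(seed: Long) extends Expr {
  // No argument: pick a random seed, as Rand's new zero-arg constructor does.
  def this() = this(Random.nextLong())
  // Expression argument: accept only an integer literal, mirroring the match above.
  def this(seed: Expr) = this(seed match {
    case IntLit(s) => s.toLong
    case _ => throw new IllegalArgumentException(
      "Input argument to rnd must be an integer literal.")
  })
}

object RndDemo extends App {
  println(new Rnd(IntLit(42)))  // Rnd(42)
  println(new Rnd().seed)       // some freshly drawn seed
}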
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 8ca8d22bc4..315c63e63c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.expressions.Substring
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -225,6 +226,10 @@ case class EndsWith(left: Expression, right: Expression)
case class Substring(str: Expression, pos: Expression, len: Expression)
extends Expression with ExpectsInputTypes {
+ def this(str: Expression, pos: Expression) = {
+ this(str, pos, Literal(Integer.MAX_VALUE))
+ }
+
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
@@ -290,12 +295,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}
}
-object Substring {
- def apply(str: Expression, pos: Expression): Substring = {
- apply(str, pos, Literal(Integer.MAX_VALUE))
- }
-}
-
/**
* A function that returns the length of the given string expression.
*/
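The effect of the new two-argument Substring constructor is that the length defaults to Integer.MAX_VALUE, i.e. "everything from pos onward". A plain-string analogue (hypothetical helper, not Spark code; SQL-style 1-based positions, negative-position handling omitted):

object SubstrDemo extends App {
  // len defaults to Integer.MAX_VALUE, so the two-argument call takes the rest of the string.
  def substr(s: String, pos: Int, len: Int = Integer.MAX_VALUE): String = {
    val start = (pos - 1).max(0).min(s.length)          // 1-based pos, clamped to the string
    val end = (start.toLong + len).min(s.length.toLong) // Long math avoids Int overflow
    s.substring(start, end.toInt)
  }

  println(substr("Spark SQL", 7))     // SQL
  println(substr("Spark SQL", 1, 5))  // Spark
}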
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 36d005d0e1..5964e3dc3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -344,11 +344,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
- val defaultCtor =
- getClass.getConstructors
- .find(_.getParameterTypes.size != 0)
- .headOption
- .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+ val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
+ if (ctors.isEmpty) {
+ sys.error(s"No valid constructor for $nodeName")
+ }
+ val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
try {
CurrentOrigin.withOrigin(origin) {
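The behavioral point of this last hunk: getConstructors returns constructors in no particular order, so find(...).headOption could hand makeCopy an auxiliary constructor (say, Substring's new two-argument form) even though makeCopy supplies all product arguments. Taking the constructor with the most parameters reliably selects the primary one. A small illustration with a hypothetical Branch class, not Spark's TreeNode:

case class Branch(left: String, right: String, tag: Int) {
  def this(left: String, right: String) = this(left, right, 0)
}

object MakeCopyDemo extends App {
  val ctors = classOf[Branch].getConstructors.filter(_.getParameterTypes.nonEmpty)
  // The widest constructor is the primary one and accepts all three product arguments.
  val primary = ctors.maxBy(_.getParameterTypes.length)
  println(primary.newInstance("l", "r", Int.box(7)))  // Branch(l,r,7)
}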