From 2d71ba4c8a629deab672869ac8e8b6a4b3aec479 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 13 Jun 2015 18:22:17 -0700 Subject: [SPARK-8349] [SQL] Use expression constructors (rather than apply) in FunctionRegistry Author: Reynold Xin Closes #6806 from rxin/gs and squashes the following commits: ed1aebb [Reynold Xin] Fixed style. c7fc3e6 [Reynold Xin] [SPARK-8349][SQL] Use expression constructors (rather than apply) in FunctionRegistry --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 18 +++++++----------- .../spark/sql/catalyst/expressions/Expression.scala | 3 +-- .../apache/spark/sql/catalyst/expressions/random.scala | 12 ++++-------- .../sql/catalyst/expressions/stringOperations.scala | 11 +++++------ .../org/apache/spark/sql/catalyst/trees/TreeNode.scala | 10 +++++----- 5 files changed, 22 insertions(+), 32 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 45bcbf73fa..04e306da23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -158,27 +158,23 @@ object FunctionRegistry { /** See usage above. */ private def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { - // Use the companion class to find apply methods. 
- val objectClass = Class.forName(tag.runtimeClass.getName + "$") - val companionObj = objectClass.getDeclaredField("MODULE$").get(null) - - // See if we can find an apply that accepts Seq[Expression] - val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption + // See if we can find a constructor that accepts Seq[Expression] + val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption val builder = (expressions: Seq[Expression]) => { - if (varargApply.isDefined) { + if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. - varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression] + varargCtor.get.newInstance(expressions).asInstanceOf[Expression] } else { - // Otherwise, find an apply method that matches the number of arguments, and use that. + // Otherwise, find a constructor that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match { + val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { case Success(e) => e case Failure(e) => throw new AnalysisException(s"Invalid number of arguments for function $name") } - f.invoke(companionObj, expressions : _*).asInstanceOf[Expression] + f.newInstance(expressions : _*).asInstanceOf[Expression] } } (name, builder) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 61de34bfa4..7427ca76b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -27,8 +27,7 @@ import org.apache.spark.sql.types._ /** * If an expression wants to be exposed in the function registry (so users can call it
with * "name(arguments...)", the concrete implementation must be a case class whose constructor - * arguments are all Expressions types. In addition, if it needs to support more than one - * constructor, define those constructors explicitly as apply methods in the companion object. + * arguments are all Expressions types. * * See [[Substring]] for an example. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 7e8033307e..cc34467391 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -49,12 +49,10 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ case class Rand(seed: Long) extends RDG(seed) { override def eval(input: InternalRow): Double = rng.nextDouble() -} -object Rand { - def apply(): Rand = apply(Utils.random.nextLong()) + def this() = this(Utils.random.nextLong()) - def apply(seed: Expression): Rand = apply(seed match { + def this(seed: Expression) = this(seed match { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) @@ -63,12 +61,10 @@ object Rand { /** Generate a random column with i.i.d. gaussian random distribution. 
*/ case class Randn(seed: Long) extends RDG(seed) { override def eval(input: InternalRow): Double = rng.nextGaussian() -} -object Randn { - def apply(): Randn = apply(Utils.random.nextLong()) + def this() = this(Utils.random.nextLong()) - def apply(seed: Expression): Randn = apply(seed match { + def this(seed: Expression) = this(seed match { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 8ca8d22bc4..315c63e63c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -225,6 +226,10 @@ case class EndsWith(left: Expression, right: Expression) case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression with ExpectsInputTypes { + def this(str: Expression, pos: Expression) = { + this(str, pos, Literal(Integer.MAX_VALUE)) + } + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable @@ -290,12 +295,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } -object Substring { - def apply(str: Expression, pos: Expression): Substring = { - apply(str, pos, Literal(Integer.MAX_VALUE)) - } -} - /** * A function that return the length of the 
given string expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 36d005d0e1..5964e3dc3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -344,11 +344,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { - val defaultCtor = - getClass.getConstructors - .find(_.getParameterTypes.size != 0) - .headOption - .getOrElse(sys.error(s"No valid constructor for $nodeName")) + val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) + if (ctors.isEmpty) { + sys.error(s"No valid constructor for $nodeName") + } + val defaultCtor = ctors.maxBy(_.getParameterTypes.size) try { CurrentOrigin.withOrigin(origin) { -- cgit v1.2.3