-rw-r--r--  python/pyspark/sql/dataframe.py | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 127
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala | 23
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 42
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 14
12 files changed, 269 insertions(+), 93 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e9dd05e2d0..9615e57649 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -746,7 +746,7 @@ class DataFrame(object):
This is a variant of :func:`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
+ [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
"""
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index e85312aee7..f74c17d583 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
import scala.language.implicitConversions
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -48,26 +49,21 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
// Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
// properties via reflection on this class at runtime to construct the SqlLexical object
- protected val ABS = Keyword("ABS")
protected val ALL = Keyword("ALL")
protected val AND = Keyword("AND")
protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AS = Keyword("AS")
protected val ASC = Keyword("ASC")
- protected val AVG = Keyword("AVG")
protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
- protected val COALESCE = Keyword("COALESCE")
- protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
protected val ELSE = Keyword("ELSE")
protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
protected val FALSE = Keyword("FALSE")
- protected val FIRST = Keyword("FIRST")
protected val FROM = Keyword("FROM")
protected val FULL = Keyword("FULL")
protected val GROUP = Keyword("GROUP")
@@ -80,13 +76,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val INTO = Keyword("INTO")
protected val IS = Keyword("IS")
protected val JOIN = Keyword("JOIN")
- protected val LAST = Keyword("LAST")
protected val LEFT = Keyword("LEFT")
protected val LIKE = Keyword("LIKE")
protected val LIMIT = Keyword("LIMIT")
- protected val LOWER = Keyword("LOWER")
- protected val MAX = Keyword("MAX")
- protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
@@ -100,15 +92,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val RLIKE = Keyword("RLIKE")
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
- protected val SQRT = Keyword("SQRT")
- protected val SUBSTR = Keyword("SUBSTR")
- protected val SUBSTRING = Keyword("SUBSTRING")
- protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
protected val THEN = Keyword("THEN")
protected val TRUE = Keyword("TRUE")
protected val UNION = Keyword("UNION")
- protected val UPPER = Keyword("UPPER")
protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")
@@ -277,25 +264,36 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
)
protected lazy val function: Parser[Expression] =
- ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) }
- | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
- | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
- | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
- | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^
- { case exps => CountDistinct(exps) }
- | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
- { case exp => ApproxCountDistinct(exp) }
- | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
- { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
- | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
- | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
- | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
- | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) }
- | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) }
- | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) }
- | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
- | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
- { case c ~ t ~ f => If(c, t, f) }
+ ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
+ if (lexical.normalizeKeyword(udfName) == "count") {
+ Count(Literal(1))
+ } else {
+ throw new AnalysisException(s"invalid expression $udfName(*)")
+ }
+ }
+ | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
+ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+ | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
+ lexical.normalizeKeyword(udfName) match {
+ case "sum" => SumDistinct(exprs.head)
+ case "count" => CountDistinct(exprs)
+ }
+ }
+ | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
+ if (lexical.normalizeKeyword(udfName) == "count") {
+ ApproxCountDistinct(exp)
+ } else {
+ throw new AnalysisException(s"invalid function approximate $udfName")
+ }
+ }
+ | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
+ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
+ if (lexical.normalizeKeyword(udfName) == "count") {
+ ApproxCountDistinct(exp, s.toDouble)
+ } else {
+ throw new AnalysisException(s"invalid function approximate($floatLit) $udfName")
+ }
+ }
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
(ELSE ~> expression).? <~ END ^^ {
case casePart ~ altPart ~ elsePart =>
@@ -304,16 +302,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
} ++ elsePart
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
}
- | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
- { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
- | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
- { case s ~ p ~ l => Substring(s, p, l) }
- | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) }
- | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
- | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
- | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
- { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
- )
+ )
protected lazy val cast: Parser[Expression] =
CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5883d938b6..02b10c444d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -461,7 +461,9 @@ class Analyzer(
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) if u.childrenResolved =>
- registry.lookupFunction(name, children)
+ withPosition(u) {
+ registry.lookupFunction(name, children)
+ }
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 0849faa9bf..406f6fad84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -17,24 +17,27 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.Expression
-import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.StringKeyHashMap
+
/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
trait FunctionRegistry {
- type FunctionBuilder = Seq[Expression] => Expression
def registerFunction(name: String, builder: FunctionBuilder): Unit
+ @throws[AnalysisException]("If function does not exist")
def lookupFunction(name: String, children: Seq[Expression]): Expression
-
- def conf: CatalystConf
}
trait OverrideFunctionRegistry extends FunctionRegistry {
- val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis)
+ private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)
override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
@@ -45,16 +48,19 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
}
}
-class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry {
+class SimpleFunctionRegistry extends FunctionRegistry {
- val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis)
+ private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)
override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- functionBuilders(name)(children)
+ val func = functionBuilders.get(name).getOrElse {
+ throw new AnalysisException(s"undefined function $name")
+ }
+ func(children)
}
}
@@ -70,30 +76,89 @@ object EmptyFunctionRegistry extends FunctionRegistry {
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}
-
- override def conf: CatalystConf = throw new UnsupportedOperationException
}
-/**
- * Build a map with String type of key, and it also supports either key case
- * sensitive or insensitive.
- * TODO move this into util folder?
- */
-object StringKeyHashMap {
- def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
- case false => new StringKeyHashMap[T](_.toLowerCase)
- case true => new StringKeyHashMap[T](identity)
- }
-}
-class StringKeyHashMap[T](normalizer: (String) => String) {
- private val base = new collection.mutable.HashMap[String, T]()
+object FunctionRegistry {
- def apply(key: String): T = base(normalizer(key))
+ type FunctionBuilder = Seq[Expression] => Expression
- def get(key: String): Option[T] = base.get(normalizer(key))
- def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
- def remove(key: String): Option[T] = base.remove(normalizer(key))
- def iterator: Iterator[(String, T)] = base.toIterator
+ val expressions: Map[String, FunctionBuilder] = Map(
+ // Non aggregate functions
+ expression[Abs]("abs"),
+ expression[CreateArray]("array"),
+ expression[Coalesce]("coalesce"),
+ expression[Explode]("explode"),
+ expression[Lower]("lower"),
+ expression[Substring]("substr"),
+ expression[Substring]("substring"),
+ expression[Rand]("rand"),
+ expression[Randn]("randn"),
+ expression[CreateStruct]("struct"),
+ expression[Sqrt]("sqrt"),
+ expression[Upper]("upper"),
+
+ // Math functions
+ expression[Acos]("acos"),
+ expression[Asin]("asin"),
+ expression[Atan]("atan"),
+ expression[Atan2]("atan2"),
+ expression[Cbrt]("cbrt"),
+ expression[Ceil]("ceil"),
+ expression[Cos]("cos"),
+ expression[Exp]("exp"),
+ expression[Expm1]("expm1"),
+ expression[Floor]("floor"),
+ expression[Hypot]("hypot"),
+ expression[Log]("log"),
+ expression[Log10]("log10"),
+ expression[Log1p]("log1p"),
+ expression[Pow]("pow"),
+ expression[Rint]("rint"),
+ expression[Signum]("signum"),
+ expression[Sin]("sin"),
+ expression[Sinh]("sinh"),
+ expression[Tan]("tan"),
+ expression[Tanh]("tanh"),
+ expression[ToDegrees]("todegrees"),
+ expression[ToRadians]("toradians"),
+
+ // aggregate functions
+ expression[Average]("avg"),
+ expression[Count]("count"),
+ expression[First]("first"),
+ expression[Last]("last"),
+ expression[Max]("max"),
+ expression[Min]("min"),
+ expression[Sum]("sum")
+ )
+
+ /** See usage above. */
+ private def expression[T <: Expression](name: String)
+ (implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
+ // Use the companion class to find apply methods.
+ val objectClass = Class.forName(tag.runtimeClass.getName + "$")
+ val companionObj = objectClass.getDeclaredField("MODULE$").get(null)
+
+ // See if we can find an apply that accepts Seq[Expression]
+ val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption
+
+ val builder = (expressions: Seq[Expression]) => {
+ if (varargApply.isDefined) {
+ // If there is an apply method that accepts Seq[Expression], use that one.
+ varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression]
+ } else {
+ // Otherwise, find an apply method that matches the number of arguments, and use that.
+ val params = Seq.fill(expressions.size)(classOf[Expression])
+ val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match {
+ case Success(e) =>
+ e
+ case Failure(e) =>
+ throw new AnalysisException(s"Invalid number of arguments for function $name")
+ }
+ f.invoke(companionObj, expressions : _*).asInstanceOf[Expression]
+ }
+ }
+ (name, builder)
+ }
}
-
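A minimal sketch of how the reflective expression[T] helper behaves at lookup time, using the registration loop that SQLContext adopts below; the commented-out calls restate the AnalysisException paths defined above:

    import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry}
    import org.apache.spark.sql.catalyst.expressions.Literal

    val registry = new SimpleFunctionRegistry
    FunctionRegistry.expressions.foreach { case (name, builder) =>
      registry.registerFunction(name, builder)
    }

    // Resolved through Sqrt's companion apply(Expression), found by reflection.
    registry.lookupFunction("sqrt", Seq(Literal(4.0)))
    // Wrong arity: no matching apply exists, so this throws
    // AnalysisException("Invalid number of arguments for function sqrt").
    // registry.lookupFunction("sqrt", Seq(Literal(1.0), Literal(2.0)))
    // Unregistered names throw AnalysisException("undefined function foo").
    // registry.lookupFunction("foo", Nil)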
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a9a9c0cfb7..f2ed1f0929 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
+
+/**
+ * For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes
+ * whose constructor arguments are all of type Expression. In addition, if we want to support more
+ * than one constructor, define those constructors explicitly as apply methods in the companion
+ * object.
+ *
+ * See [[Substring]] for an example.
+ */
abstract class Expression extends TreeNode[Expression] {
self: Product =>
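The Substring changes later in this patch are the worked example the comment points to: the case class keeps a single three-argument constructor, and a companion apply supplies the two-argument form, so expression[Substring]("substr") can pick an apply method by argument count. A small sketch of the two call shapes:

    import org.apache.spark.sql.catalyst.expressions.{Literal, Substring}

    val str = Literal("abcd")
    Substring(str, Literal(2), Literal(3))  // primary constructor: substr(a, 2, 3)
    Substring(str, Literal(2))              // companion apply: substr(a, 2), len defaults to Integer.MAX_VALUE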
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index b2647124c4..6e4e9cb1be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.TaskContext
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -46,11 +47,29 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
-case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+case class Rand(seed: Long) extends RDG(seed) {
override def eval(input: Row): Double = rng.nextDouble()
}
+object Rand {
+ def apply(): Rand = apply(Utils.random.nextLong())
+
+ def apply(seed: Expression): Rand = apply(seed match {
+ case IntegerLiteral(s) => s
+ case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
+ })
+}
+
/** Generate a random column with i.i.d. gaussian random distribution. */
-case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+case class Randn(seed: Long) extends RDG(seed) {
override def eval(input: Row): Double = rng.nextGaussian()
}
+
+object Randn {
+ def apply(): Randn = apply(Utils.random.nextLong())
+
+ def apply(seed: Expression): Randn = apply(seed match {
+ case IntegerLiteral(s) => s
+ case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
+ })
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index aae122a981..856f56488c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -227,6 +227,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
+
override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
@@ -287,3 +288,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
case _ => s"SUBSTR($str, $pos, $len)"
}
}
+
+object Substring {
+ def apply(str: Expression, pos: Expression): Substring = {
+ apply(str, pos, Literal(Integer.MAX_VALUE))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
new file mode 100644
index 0000000000..191d5e6399
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+/**
+ * A hash map keyed by String that supports either case-sensitive or
+ * case-insensitive key lookup.
+ */
+object StringKeyHashMap {
+ def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
+ case false => new StringKeyHashMap[T](_.toLowerCase)
+ case true => new StringKeyHashMap[T](identity)
+ }
+}
+
+
+class StringKeyHashMap[T](normalizer: (String) => String) {
+ private val base = new collection.mutable.HashMap[String, T]()
+
+ def apply(key: String): T = base(normalizer(key))
+
+ def get(key: String): Option[T] = base.get(normalizer(key))
+
+ def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
+
+ def remove(key: String): Option[T] = base.remove(normalizer(key))
+
+ def iterator: Iterator[(String, T)] = base.toIterator
+}
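A brief usage sketch of the two modes; the registries in this patch always construct the map with caseSensitive = false:

    import org.apache.spark.sql.catalyst.util.StringKeyHashMap

    val m = StringKeyHashMap[Int](caseSensitive = false)
    m.put("Count", 1)
    m.get("COUNT")  // Some(1): keys are normalized with toLowerCase

    val strict = StringKeyHashMap[Int](caseSensitive = true)
    strict.put("Count", 1)
    strict.get("count")  // None: the identity normalizer preserves case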
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ddb54025ba..8cad3885b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,7 +120,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf)
+ protected[sql] lazy val functionRegistry: FunctionRegistry = {
+ val fr = new SimpleFunctionRegistry
+ FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) }
+ fr
+ }
@transient
protected[sql] lazy val analyzer: Analyzer =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 064c040d2b..703a34c47e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -25,6 +25,48 @@ class UDFSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
+ test("built-in fixed arity expressions") {
+ val df = ctx.emptyDataFrame
+ df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
+ }
+
+ test("built-in vararg expressions") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ df.selectExpr("array(a, b)")
+ df.selectExpr("struct(a, b)")
+ }
+
+ test("built-in expressions with multiple constructors") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect()
+ }
+
+ test("count") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("count(a)")
+ }
+
+ test("count distinct") {
+ val df = Seq(("abcd", 2)).toDF("a", "b")
+ df.selectExpr("count(distinct a)")
+ }
+
+ test("error reporting for incorrect number of arguments") {
+ val df = ctx.emptyDataFrame
+ val e = intercept[AnalysisException] {
+ df.selectExpr("substr('abcd', 2, 3, 4)")
+ }
+ assert(e.getMessage.contains("arguments"))
+ }
+
+ test("error reporting for undefined functions") {
+ val df = ctx.emptyDataFrame
+ val e = intercept[AnalysisException] {
+ df.selectExpr("a_function_that_does_not_exist()")
+ }
+ assert(e.getMessage.contains("undefined function"))
+ }
+
test("Simple UDF") {
ctx.udf.register("strLenScala", (_: String).length)
assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index b8f294c262..3b8cafb4a6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -39,13 +39,12 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.sources.DataSourceStrategy
-import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -374,10 +373,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
- override protected[sql] lazy val functionRegistry =
- new HiveFunctionRegistry with OverrideFunctionRegistry {
- override def conf: CatalystConf = currentSession().conf
- }
+ override protected[sql] lazy val functionRegistry: FunctionRegistry =
+ new HiveFunctionRegistry with OverrideFunctionRegistry
/* An analyzer that uses the Hive metastore. */
@transient
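A sketch of the override behavior the comment above describes: names registered on this context win, and anything else falls back to Hive's own function registry (the exact lookup order is the OverrideFunctionRegistry contract rather than anything shown in this hunk; the session and UDF below are hypothetical):

    // Assumes an existing HiveContext named hiveCtx.
    hiveCtx.udf.register("upper", (s: String) => s.toUpperCase + "!")
    // "upper" now resolves to the UDF registered above instead of Hive's builtin,
    // while an unregistered name such as concat_ws still resolves through Hive.
    hiveCtx.sql("SELECT upper('a'), concat_ws('-', 'a', 'b')")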
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 01f47352b2..6e6ac987b6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -17,11 +17,8 @@
package org.apache.spark.sql.hive
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.spark.sql.AnalysisException
-
import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
@@ -30,8 +27,11 @@ import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
@@ -40,20 +40,18 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._
-/* Implicit conversions */
-import scala.collection.JavaConversions._
private[hive] abstract class HiveFunctionRegistry
extends analysis.FunctionRegistry with HiveInspectors {
def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
- def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo =
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
- sys.error(s"Couldn't find function $name"))
+ throw new AnalysisException(s"undefined function $name"))
val functionClassName = functionInfo.getFunctionClass.getName