aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-02-03 12:12:26 -0800
committerMichael Armbrust <michael@databricks.com>2015-02-03 12:12:26 -0800
commitca7a6cdff004eb4605fd223e127b4a46a0a214e7 (patch)
tree31e7c93b147b264557e45eea777dce63e9343ff7 /sql
parent0c20ce69fb4bcb1cec5313a9d072826c5588cbbc (diff)
downloadspark-ca7a6cdff004eb4605fd223e127b4a46a0a214e7.tar.gz
spark-ca7a6cdff004eb4605fd223e127b4a46a0a214e7.tar.bz2
spark-ca7a6cdff004eb4605fd223e127b4a46a0a214e7.zip
[SPARK-5550] [SQL] Support the case insensitive for UDF
SQL in HiveContext should be case insensitive; however, the following query fails because the registered UDF name is matched case-sensitively. ```scala udf.register("random0", () => { Math.random()}) assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) ``` Author: Cheng Hao <hao.cheng@intel.com> Closes #4326 from chenghao-intel/udf_case_sensitive and squashes the following commits: 485cf66 [Cheng Hao] Support the case insensitive for UDF
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala36
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala2
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala36
4 files changed, 72 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 760c49fbca..9f334f6d42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -27,23 +27,25 @@ trait FunctionRegistry {
def registerFunction(name: String, builder: FunctionBuilder): Unit
def lookupFunction(name: String, children: Seq[Expression]): Expression
+
+ def caseSensitive: Boolean
}
trait OverrideFunctionRegistry extends FunctionRegistry {
- val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+ val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
def registerFunction(name: String, builder: FunctionBuilder) = {
functionBuilders.put(name, builder)
}
abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children))
+ functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
}
}
-class SimpleFunctionRegistry extends FunctionRegistry {
- val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry {
+ val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
def registerFunction(name: String, builder: FunctionBuilder) = {
functionBuilders.put(name, builder)
@@ -64,4 +66,30 @@ object EmptyFunctionRegistry extends FunctionRegistry {
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}
+
+ def caseSensitive: Boolean = ???
+}
+
+/**
+ * Builds a map with String keys whose key matching is either case sensitive
+ * or case insensitive, depending on how the map is created.
+ * TODO move this into util folder?
+ */
+object StringKeyHashMap {
+ def apply[T](caseSensitive: Boolean) = caseSensitive match {
+ case false => new StringKeyHashMap[T](_.toLowerCase)
+ case true => new StringKeyHashMap[T](identity)
+ }
+}
+
+class StringKeyHashMap[T](normalizer: (String) => String) {
+ private val base = new collection.mutable.HashMap[String, T]()
+
+ def apply(key: String): T = base(normalizer(key))
+
+ def get(key: String): Option[T] = base.get(normalizer(key))
+ def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
+ def remove(key: String): Option[T] = base.remove(normalizer(key))
+ def iterator: Iterator[(String, T)] = base.toIterator
}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a741d0031d..2697e780c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -87,7 +87,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry
+ protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
@transient
protected[sql] lazy val analyzer: Analyzer =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index f6d9027f90..50f266a4bc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -311,7 +311,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry =
- new HiveFunctionRegistry with OverrideFunctionRegistry
+ new HiveFunctionRegistry with OverrideFunctionRegistry {
+ def caseSensitive = false
+ }
/* An analyzer that uses the Hive metastore. */
@transient
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
new file mode 100644
index 0000000000..85b6bc93d7
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+/* Implicits */
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.hive.test.TestHive._
+
+case class FunctionResult(f1: String, f2: String)
+
+class UDFSuite extends QueryTest {
+ test("UDF case insensitive") {
+ udf.register("random0", () => { Math.random()})
+ udf.register("RANDOM1", () => { Math.random()})
+ udf.register("strlenScala", (_: String).length + (_:Int))
+ assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
+ }
+}