diff options
author | Michael Armbrust <michael@databricks.com> | 2014-08-02 16:33:48 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-08-02 16:33:48 -0700 |
commit | 158ad0bba9382fd494b4789b5628a9cec00cfa19 (patch) | |
tree | 761acc1c8694a167043fc8f45bfa49447d6c1f2d /sql/hive/src | |
parent | 4c477117bb1ffef463776c86f925d35036f96b7a (diff) | |
download | spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.tar.gz spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.tar.bz2 spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.zip |
[SPARK-2097][SQL] UDF Support
This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL.
Scala:
```scala
registerFunction("strLenScala", (_: String).length)
sql("SELECT strLenScala('test')")
```
Python:
```python
sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType())
sqlCtx.sql("SELECT strLenPython('test')")
```
Java:
```java
sqlContext.registerFunction("stringLengthJava", new UDF1<String, Integer>() {
Override
public Integer call(String str) throws Exception {
return str.length();
}
}, DataType.IntegerType);
sqlContext.sql("SELECT stringLengthJava('test')");
```
Author: Michael Armbrust <michael@databricks.com>
Closes #1063 from marmbrus/udfs and squashes the following commits:
9eda0fe [Michael Armbrust] newline
747c05e [Michael Armbrust] Add some scala UDF tests.
d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
005d684 [Michael Armbrust] Fix naming and formatting.
d14dac8 [Michael Armbrust] Fix last line of autogened java files.
8135c48 [Michael Armbrust] Move UDF unit tests to pyspark.
40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable.
7a83101 [Michael Armbrust] Drop toString
795fd15 [Michael Armbrust] Try to avoid capturing SQLContext.
e54fb45 [Michael Armbrust] Docs and tests.
437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments.
01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
8e6c932 [Michael Armbrust] WIP
3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
6237c8d [Michael Armbrust] WIP
2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs.
0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python.
Diffstat (limited to 'sql/hive/src')
4 files changed, 16 insertions, 11 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2c7270d9f8..3c70b3f092 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver @@ -35,8 +35,9 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand @@ -155,10 +156,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + // Note that HiveUDFs will be overridden by functions registered in this context. + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry + /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -250,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = - optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 728452a25a..c605e8adcf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -297,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => - logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d181921269..179aac5cbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { +private[hive] abstract class HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) @@ -92,9 +93,8 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu } private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf { + extends HiveUdf with HiveInspectors { - import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 11d8b1f0a3..95921c3d7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -51,9 +51,9 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == - |$e + |${stackTraceToString(e)} """.stripMargin) } |