author     petermaxlee <petermaxlee@gmail.com>      2016-07-13 08:05:20 +0800
committer  Wenchen Fan <wenchen@databricks.com>     2016-07-13 08:05:20 +0800
commit     56bd399a86c4e92be412d151200cb5e4a5f6a48a (patch)
tree       ef355d9f472cc20015240478829e0ab1d2c4f4d1 /sql
parent     7f968867ff61c6b1a007874ee7e3a7421d94d373 (diff)
[SPARK-16284][SQL] Implement reflect SQL function
## What changes were proposed in this pull request?

This patch implements the reflect SQL function, which can be used to invoke a Java method in SQL. Slightly different from Hive, this implementation requires the class name and the method name to be literals. It also supports a smaller set of data types, and requires the method to be static, as suggested by rxin in #13969.

java_method is an alias for reflect, so this should also resolve SPARK-16277.

## How was this patch tested?

Added expression unit tests and an end-to-end test.

Author: petermaxlee <petermaxlee@gmail.com>

Closes #14138 from petermaxlee/reflect-static.
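For reference, a minimal usage sketch of the new function (an editorial aside, not part of the patch; it assumes a local SparkSession and Spark 2.0+ APIs, and the UUID value is of course random per run):

```scala
import org.apache.spark.sql.SparkSession

object ReflectExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("reflect-example")
      .getOrCreate()

    // reflect and java_method resolve to the same CallMethodViaReflection expression;
    // only public static methods can be invoked, and the result is returned as a string.
    spark.sql("SELECT reflect('java.util.UUID', 'randomUUID')").show(truncate = false)
    spark.sql("SELECT java_method('java.lang.Math', 'max', 2, 3)").show() // "3"

    spark.stop()
  }
}
```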
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala          |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala | 164
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala | 102
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala                               |  38
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala                          |   7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala                     |  23
6 files changed, 311 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 54568b7445..65a90d8099 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -368,6 +368,8 @@ object FunctionRegistry {
expression[InputFileName]("input_file_name"),
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
expression[CurrentDatabase]("current_database"),
+ expression[CallMethodViaReflection]("reflect"),
+ expression[CallMethodViaReflection]("java_method"),
// grouping sets
expression[Cube]("cube"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
new file mode 100644
index 0000000000..fe24c0489f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.lang.reflect.{Method, Modifier}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+/**
+ * An expression that invokes a method on a class via reflection.
+ *
+ * For now, only the types defined in `CallMethodViaReflection.typeMapping` are supported
+ * (basically primitives and string) as input types, and the output is automatically
+ * converted to a string.
+ *
+ * Note that unlike Hive's reflect function, this expression calls only static methods
+ * (i.e. does not support calling non-static methods).
+ *
+ * We should also look into how to consolidate this expression with
+ * [[org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke]] in the future.
+ *
+ * @param children the first element should be a literal string for the class name,
+ * and the second element should be a literal string for the method name,
+ * and the remaining are input arguments to the Java method.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(class,method[,arg1[,arg2..]]) calls method with reflection",
+ extended = "> SELECT _FUNC_('java.util.UUID', 'randomUUID');\n c33fb387-8500-4bfa-81d2-6e0e3e930df2")
+// scalastyle:on line.size.limit
+case class CallMethodViaReflection(children: Seq[Expression])
+ extends Expression with CodegenFallback {
+
+ override def prettyName: String = "reflect"
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size < 2) {
+ TypeCheckFailure("requires at least two arguments")
+ } else if (!children.take(2).forall(e => e.dataType == StringType && e.foldable)) {
+ // The first two arguments must be foldable string expressions, i.e. string literals.
+ TypeCheckFailure("first two arguments should be string literals")
+ } else if (!classExists) {
+ TypeCheckFailure(s"class $className not found")
+ } else if (method == null) {
+ TypeCheckFailure(s"cannot find a static method that matches the argument types in $className")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def deterministic: Boolean = false
+ override def nullable: Boolean = true
+ override val dataType: DataType = StringType
+
+ override def eval(input: InternalRow): Any = {
+ var i = 0
+ while (i < argExprs.length) {
+ buffer(i) = argExprs(i).eval(input).asInstanceOf[Object]
+ // Convert if necessary. Based on the types defined in typeMapping, string is the only
+ // type that needs conversion. If we support timestamps, dates, decimals, arrays, or maps
+ // in the future, proper conversion needs to happen here too.
+ if (buffer(i).isInstanceOf[UTF8String]) {
+ buffer(i) = buffer(i).toString
+ }
+ i += 1
+ }
+ val ret = method.invoke(null, buffer : _*)
+ UTF8String.fromString(String.valueOf(ret))
+ }
+
+ @transient private lazy val argExprs: Array[Expression] = children.drop(2).toArray
+
+ /** Name of the class -- valid to access only after verifying `children` has at least two exprs. */
+ @transient private lazy val className = children(0).eval().asInstanceOf[UTF8String].toString
+
+ /** True if the class exists and can be loaded. */
+ @transient private lazy val classExists = CallMethodViaReflection.classExists(className)
+
+ /** The reflection method. */
+ @transient lazy val method: Method = {
+ val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString
+ CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull
+ }
+
+ /** A temporary buffer used to hold intermediate results returned by children. */
+ @transient private lazy val buffer = new Array[Object](argExprs.length)
+}
+
+object CallMethodViaReflection {
+ /** Mapping from Spark's type to acceptable JVM types. */
+ val typeMapping = Map[DataType, Seq[Class[_]]](
+ BooleanType -> Seq(classOf[java.lang.Boolean], classOf[Boolean]),
+ ByteType -> Seq(classOf[java.lang.Byte], classOf[Byte]),
+ ShortType -> Seq(classOf[java.lang.Short], classOf[Short]),
+ IntegerType -> Seq(classOf[java.lang.Integer], classOf[Int]),
+ LongType -> Seq(classOf[java.lang.Long], classOf[Long]),
+ FloatType -> Seq(classOf[java.lang.Float], classOf[Float]),
+ DoubleType -> Seq(classOf[java.lang.Double], classOf[Double]),
+ StringType -> Seq(classOf[String])
+ )
+
+ /**
+ * Returns true if the class can be found and loaded.
+ */
+ private def classExists(className: String): Boolean = {
+ try {
+ Utils.classForName(className)
+ true
+ } catch {
+ case _: ClassNotFoundException => false
+ }
+ }
+
+ /**
+ * Finds, via reflection, a public static method on the given class that matches the given
+ * argument types. Note that the method's return type is not constrained here; the caller
+ * converts whatever is returned into a string.
+ *
+ * Every element of `argTypes` must be one of the types defined in [[typeMapping]].
+ *
+ * This is made public for unit testing.
+ */
+ def findMethod(className: String, methodName: String, argTypes: Seq[DataType]): Option[Method] = {
+ val clazz: Class[_] = Utils.classForName(className)
+ clazz.getMethods.find { method =>
+ val candidateTypes = method.getParameterTypes
+ if (method.getName != methodName) {
+ // Name must match
+ false
+ } else if (!Modifier.isStatic(method.getModifiers)) {
+ // Method must be static
+ false
+ } else if (candidateTypes.length != argTypes.length) {
+ // Argument length must match
+ false
+ } else {
+ // Argument types must match. That is, each of the method's parameter types either matches
+ // one of the acceptable JVM types defined in typeMapping, or is a supertype of one of them.
+ candidateTypes.zip(argTypes).forall { case (candidateType, argType) =>
+ typeMapping(argType).exists(candidateType.isAssignableFrom)
+ }
+ }
+ }
+ }
+}
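The matching rule above is worth a standalone illustration. The sketch below (an editorial aside, not part of the patch; it needs only the JDK) shows why `typeMapping` pairs each Spark type with both the boxed and the primitive JVM class: `isAssignableFrom` does not model boxing or primitive widening, so `Math.max(int, int)` is found through `classOf[Int]` while the long/float/double overloads are rejected:

```scala
import java.lang.reflect.Modifier

object ReflectionMatchSketch {
  // Acceptable JVM classes for an IntegerType argument, mirroring typeMapping.
  private val intClasses: Seq[Class[_]] = Seq(classOf[java.lang.Integer], classOf[Int])

  def main(args: Array[String]): Unit = {
    val matched = classOf[java.lang.Math].getMethods.find { m =>
      m.getName == "max" &&
        Modifier.isStatic(m.getModifiers) &&
        m.getParameterTypes.length == 2 &&
        m.getParameterTypes.forall(p => intClasses.exists(p.isAssignableFrom))
    }
    // Only the (int, int) overload matches; invoke boxes the arguments for us.
    println(matched) // Some(public static int java.lang.Math.max(int,int))
    println(matched.get.invoke(null, Integer.valueOf(2), Integer.valueOf(3))) // 3
  }
}
```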
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala
new file mode 100644
index 0000000000..43367c7e14
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+/** A Scala object, whose methods are static from Java's perspective, for testing purposes. */
+object ReflectStaticClass {
+ def method1(): String = "m1"
+ def method2(v1: Int): String = "m" + v1
+ def method3(v1: java.lang.Integer): String = "m" + v1
+ def method4(v1: Int, v2: String): String = "m" + v1 + v2
+}
+
+/** A class with only instance (non-static) methods, for testing purposes. */
+class ReflectDynamicClass {
+ def method1(): String = "m1"
+}
+
+/**
+ * Test suite for [[CallMethodViaReflection]] and its companion object.
+ */
+class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ import CallMethodViaReflection._
+
+ // Strip the trailing $ so we get the class name of the object as seen from Java.
+ private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
+ private val dynamicClassName = classOf[ReflectDynamicClass].getName
+
+ test("findMethod via reflection for static methods") {
+ assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
+ assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
+ assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
+ assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)
+ }
+
+ test("findMethod for a JDK library") {
+ assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)
+ }
+
+ test("class not found") {
+ val ret = createExpr("some-random-class", "method").checkInputDataTypes()
+ assert(ret.isFailure)
+ val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
+ assert(errorMsg.contains("not found") && errorMsg.contains("class"))
+ }
+
+ test("method not found because name does not match") {
+ val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
+ assert(ret.isFailure)
+ val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
+ assert(errorMsg.contains("cannot find a static method"))
+ }
+
+ test("method not found because there is no static method") {
+ val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
+ assert(ret.isFailure)
+ val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
+ assert(errorMsg.contains("cannot find a static method"))
+ }
+
+ test("input type checking") {
+ assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure)
+ assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure)
+ assert(CallMethodViaReflection(
+ Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
+ assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)
+ }
+
+ test("invoking methods using acceptable types") {
+ checkEvaluation(createExpr(staticClassName, "method1"), "m1")
+ checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
+ checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
+ checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
+ }
+
+ private def createExpr(className: String, methodName: String, args: Any*) = {
+ CallMethodViaReflection(
+ Literal.create(className, StringType) +:
+ Literal.create(methodName, StringType) +:
+ args.map(Literal.apply)
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
new file mode 100644
index 0000000000..a5b08f7177
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("reflect and java_method") {
+ val df = Seq((1, "one")).toDF("a", "b")
+ val className = ReflectClass.getClass.getName.stripSuffix("$")
+ checkAnswer(
+ df.selectExpr(
+ s"reflect('$className', 'method1', a, b)",
+ s"java_method('$className', 'method1', a, b)"),
+ Row("m1one", "m1one"))
+ }
+}
+
+object ReflectClass {
+ def method1(v1: Int, v2: String): String = "m" + v1 + v2
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 6f36abc4db..b8a75850b1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -235,7 +235,10 @@ private[sql] class HiveSessionCatalog(
// parse_url_tuple, posexplode, reflect2,
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
- "hash", "java_method", "histogram_numeric",
- "percentile", "percentile_approx", "reflect", "str_to_map"
+ "hash",
+ "histogram_numeric",
+ "percentile",
+ "percentile_approx",
+ "str_to_map"
)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index a43f0d0d7e..961d95c268 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -996,29 +996,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L))
}
- // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest
- test("udf_java_method") {
- checkAnswer(sql(
- """
- |SELECT java_method("java.lang.String", "valueOf", 1),
- | java_method("java.lang.String", "isEmpty"),
- | java_method("java.lang.Math", "max", 2, 3),
- | java_method("java.lang.Math", "min", 2, 3),
- | java_method("java.lang.Math", "round", 2.5D),
- | java_method("java.lang.Math", "exp", 1.0D),
- | java_method("java.lang.Math", "floor", 1.9D)
- |FROM src tablesample (1 rows)
- """.stripMargin),
- Row(
- "1",
- "true",
- java.lang.Math.max(2, 3).toString,
- java.lang.Math.min(2, 3).toString,
- java.lang.Math.round(2.5).toString,
- java.lang.Math.exp(1.0).toString,
- java.lang.Math.floor(1.9).toString))
- }
-
test("dynamic partition value test") {
try {
sql("set hive.exec.dynamic.partition.mode=nonstrict")