From 56bd399a86c4e92be412d151200cb5e4a5f6a48a Mon Sep 17 00:00:00 2001 From: petermaxlee Date: Wed, 13 Jul 2016 08:05:20 +0800 Subject: [SPARK-16284][SQL] Implement reflect SQL function ## What changes were proposed in this pull request? This patch implements reflect SQL function, which can be used to invoke a Java method in SQL. Slightly different from Hive, this implementation requires the class name and the method name to be literals. This implementation also supports only a smaller number of data types, and requires the function to be static, as suggested by rxin in #13969. java_method is an alias for reflect, so this should also resolve SPARK-16277. ## How was this patch tested? Added expression unit tests and an end-to-end test. Author: petermaxlee Closes #14138 from petermaxlee/reflect-static. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/CallMethodViaReflection.scala | 164 +++++++++++++++++++++ .../expressions/CallMethodViaReflectionSuite.scala | 102 +++++++++++++ .../org/apache/spark/sql/MiscFunctionsSuite.scala | 38 +++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 7 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 23 --- 6 files changed, 311 insertions(+), 25 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 54568b7445..65a90d8099 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -368,6 +368,8 @@ object FunctionRegistry { expression[InputFileName]("input_file_name"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), expression[CurrentDatabase]("current_database"), + expression[CallMethodViaReflection]("reflect"), + expression[CallMethodViaReflection]("java_method"), // grouping sets expression[Cube]("cube"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala new file mode 100644 index 0000000000..fe24c0489f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.lang.reflect.{Method, Modifier} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * An expression that invokes a method on a class via reflection. + * + * For now, only types defined in `Reflect.typeMapping` are supported (basically primitives + * and string) as input types, and the output is turned automatically to a string. + * + * Note that unlike Hive's reflect function, this expression calls only static methods + * (i.e. does not support calling non-static methods). + * + * We should also look into how to consolidate this expression with + * [[org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke]] in the future. + * + * @param children the first element should be a literal string for the class name, + * and the second element should be a literal string for the method name, + * and the remaining are input arguments to the Java method. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(class,method[,arg1[,arg2..]]) calls method with reflection", + extended = "> SELECT _FUNC_('java.util.UUID', 'randomUUID');\n c33fb387-8500-4bfa-81d2-6e0e3e930df2") +// scalastyle:on line.size.limit +case class CallMethodViaReflection(children: Seq[Expression]) + extends Expression with CodegenFallback { + + override def prettyName: String = "reflect" + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size < 2) { + TypeCheckFailure("requires at least two arguments") + } else if (!children.take(2).forall(e => e.dataType == StringType && e.foldable)) { + // The first two arguments must be string type. + TypeCheckFailure("first two arguments should be string literals") + } else if (!classExists) { + TypeCheckFailure(s"class $className not found") + } else if (method == null) { + TypeCheckFailure(s"cannot find a static method that matches the argument types in $className") + } else { + TypeCheckSuccess + } + } + + override def deterministic: Boolean = false + override def nullable: Boolean = true + override val dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + var i = 0 + while (i < argExprs.length) { + buffer(i) = argExprs(i).eval(input).asInstanceOf[Object] + // Convert if necessary. Based on the types defined in typeMapping, string is the only + // type that needs conversion. If we support timestamps, dates, decimals, arrays, or maps + // in the future, proper conversion needs to happen here too. + if (buffer(i).isInstanceOf[UTF8String]) { + buffer(i) = buffer(i).toString + } + i += 1 + } + val ret = method.invoke(null, buffer : _*) + UTF8String.fromString(String.valueOf(ret)) + } + + @transient private lazy val argExprs: Array[Expression] = children.drop(2).toArray + + /** Name of the class -- this has to be called after we verify children has at least two exprs. */ + @transient private lazy val className = children(0).eval().asInstanceOf[UTF8String].toString + + /** True if the class exists and can be loaded. */ + @transient private lazy val classExists = CallMethodViaReflection.classExists(className) + + /** The reflection method. */ + @transient lazy val method: Method = { + val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString + CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull + } + + /** A temporary buffer used to hold intermediate results returned by children. */ + @transient private lazy val buffer = new Array[Object](argExprs.length) +} + +object CallMethodViaReflection { + /** Mapping from Spark's type to acceptable JVM types. */ + val typeMapping = Map[DataType, Seq[Class[_]]]( + BooleanType -> Seq(classOf[java.lang.Boolean], classOf[Boolean]), + ByteType -> Seq(classOf[java.lang.Byte], classOf[Byte]), + ShortType -> Seq(classOf[java.lang.Short], classOf[Short]), + IntegerType -> Seq(classOf[java.lang.Integer], classOf[Int]), + LongType -> Seq(classOf[java.lang.Long], classOf[Long]), + FloatType -> Seq(classOf[java.lang.Float], classOf[Float]), + DoubleType -> Seq(classOf[java.lang.Double], classOf[Double]), + StringType -> Seq(classOf[String]) + ) + + /** + * Returns true if the class can be found and loaded. + */ + private def classExists(className: String): Boolean = { + try { + Utils.classForName(className) + true + } catch { + case e: ClassNotFoundException => false + } + } + + /** + * Finds a Java static method using reflection that matches the given argument types, + * and whose return type is string. + * + * The types sequence must be the valid types defined in [[typeMapping]]. + * + * This is made public for unit testing. + */ + def findMethod(className: String, methodName: String, argTypes: Seq[DataType]): Option[Method] = { + val clazz: Class[_] = Utils.classForName(className) + clazz.getMethods.find { method => + val candidateTypes = method.getParameterTypes + if (method.getName != methodName) { + // Name must match + false + } else if (!Modifier.isStatic(method.getModifiers)) { + // Method must be static + false + } else if (candidateTypes.length != argTypes.length) { + // Argument length must match + false + } else { + // Argument type must match. That is, either the method's argument type matches one of the + // acceptable types defined in typeMapping, or it is a super type of the acceptable types. + candidateTypes.zip(argTypes).forall { case (candidateType, argType) => + typeMapping(argType).exists(candidateType.isAssignableFrom) + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala new file mode 100644 index 0000000000..43367c7e14 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types.{IntegerType, StringType} + +/** A static class for testing purpose. */ +object ReflectStaticClass { + def method1(): String = "m1" + def method2(v1: Int): String = "m" + v1 + def method3(v1: java.lang.Integer): String = "m" + v1 + def method4(v1: Int, v2: String): String = "m" + v1 + v2 +} + +/** A non-static class for testing purpose. */ +class ReflectDynamicClass { + def method1(): String = "m1" +} + +/** + * Test suite for [[CallMethodViaReflection]] and its companion object. + */ +class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper { + + import CallMethodViaReflection._ + + // Get rid of the $ so we are getting the companion object's name. + private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$") + private val dynamicClassName = classOf[ReflectDynamicClass].getName + + test("findMethod via reflection for static methods") { + assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1")) + assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined) + assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined) + assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined) + } + + test("findMethod for a JDK library") { + assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined) + } + + test("class not found") { + val ret = createExpr("some-random-class", "method").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("not found") && errorMsg.contains("class")) + } + + test("method not found because name does not match") { + val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("cannot find a static method")) + } + + test("method not found because there is no static method") { + val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("cannot find a static method")) + } + + test("input type checking") { + assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure) + assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure) + assert(CallMethodViaReflection( + Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure) + assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess) + } + + test("invoking methods using acceptable types") { + checkEvaluation(createExpr(staticClassName, "method1"), "m1") + checkEvaluation(createExpr(staticClassName, "method2", 2), "m2") + checkEvaluation(createExpr(staticClassName, "method3", 3), "m3") + checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four") + } + + private def createExpr(className: String, methodName: String, args: Any*) = { + CallMethodViaReflection( + Literal.create(className, StringType) +: + Literal.create(methodName, StringType) +: + args.map(Literal.apply) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala new file mode 100644 index 0000000000..a5b08f7177 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class MiscFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("reflect and java_method") { + val df = Seq((1, "one")).toDF("a", "b") + val className = ReflectClass.getClass.getName.stripSuffix("$") + checkAnswer( + df.selectExpr( + s"reflect('$className', 'method1', a, b)", + s"java_method('$className', 'method1', a, b)"), + Row("m1one", "m1one")) + } +} + +object ReflectClass { + def method1(v1: Int, v2: String): String = "m" + v1 + v2 +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 6f36abc4db..b8a75850b1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -235,7 +235,10 @@ private[sql] class HiveSessionCatalog( // parse_url_tuple, posexplode, reflect2, // str_to_map, windowingtablefunction. private val hiveFunctions = Seq( - "hash", "java_method", "histogram_numeric", - "percentile", "percentile_approx", "reflect", "str_to_map" + "hash", + "histogram_numeric", + "percentile", + "percentile_approx", + "str_to_map" ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a43f0d0d7e..961d95c268 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -996,29 +996,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) } - // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest - test("udf_java_method") { - checkAnswer(sql( - """ - |SELECT java_method("java.lang.String", "valueOf", 1), - | java_method("java.lang.String", "isEmpty"), - | java_method("java.lang.Math", "max", 2, 3), - | java_method("java.lang.Math", "min", 2, 3), - | java_method("java.lang.Math", "round", 2.5D), - | java_method("java.lang.Math", "exp", 1.0D), - | java_method("java.lang.Math", "floor", 1.9D) - |FROM src tablesample (1 rows) - """.stripMargin), - Row( - "1", - "true", - java.lang.Math.max(2, 3).toString, - java.lang.Math.min(2, 3).toString, - java.lang.Math.round(2.5).toString, - java.lang.Math.exp(1.0).toString, - java.lang.Math.floor(1.9).toString)) - } - test("dynamic partition value test") { try { sql("set hive.exec.dynamic.partition.mode=nonstrict") -- cgit v1.2.3