 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala               | 23 +++++++++++++++++++++++
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala |  7 +++++++
 2 files changed, 30 insertions(+), 0 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index b8cced0b80..087b0c087c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -26,11 +26,13 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
+import com.google.common.base.Objects
import org.apache.avro.Schema
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
@@ -45,6 +47,7 @@ private[hive] object HiveShim {
// scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
val UNLIMITED_DECIMAL_PRECISION = 38
val UNLIMITED_DECIMAL_SCALE = 18
+ val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
/*
* This function became private in hive-0.13, but we have to do this to work around a Hive bug
@@ -123,6 +126,26 @@ private[hive] object HiveShim {
// for Serialization
def this() = this(null)
+ override def hashCode(): Int = {
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
+ } else {
+ functionClassName.hashCode()
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
+ // For a UDF macro, also check that both wrappers point to the same underlying macro body
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
+ instance.asInstanceOf[GenericUDFMacro].getBody()
+ } else {
+ true
+ }
+ case _ => false
+ }
+
@transient
def deserializeObjectByKryo[T: ClassTag](
kryo: Kryo,
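
The two overrides above give HiveFunctionWrapper value equality: ordinary UDFs compare by function class name alone, while macros, which all share the single GenericUDFMacro class, must additionally compare their parsed bodies. A minimal, self-contained sketch of that equality pattern (illustrative stand-in types, not the real Spark/Hive classes):

    import com.google.common.base.Objects

    // Stand-in for the parsed macro body (the real patch compares
    // GenericUDFMacro.getBody()).
    case class MacroBody(text: String)

    class FunctionWrapper(val functionClassName: String, val body: Option[MacroBody]) {
      private val MacroCls = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"

      // Macros hash on (class name, body); everything else hashes on the
      // class name alone, mirroring the diff above.
      override def hashCode(): Int =
        if (functionClassName == MacroCls) Objects.hashCode(functionClassName, body.orNull)
        else functionClassName.hashCode

      // Matching class names suffice for ordinary UDFs; macros must also
      // agree on the underlying body.
      override def equals(other: Any): Boolean = other match {
        case w: FunctionWrapper if functionClassName == w.functionClassName =>
          functionClassName != MacroCls || body == w.body
        case _ => false
      }
    }

Two wrappers for an ordinary UDF class compare equal by class name alone; two macro wrappers additionally compare their bodies, mirroring the getBody() comparison in the patch.
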
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index c5ff8825ab..dfe33ba8b0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -350,6 +350,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
sqlContext.dropTempTable("testUDF")
}
+ test("Hive UDF in group by") {
+ Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1")
+ val count = sql("select date(cast(test_date as timestamp))" +
+ " from tab1 group by date(cast(test_date as timestamp))").count()
+ assert(count == 1)
+ }
+
test("SPARK-11522 select input_file_name from non-parquet table"){
withTempDir { tempDir =>
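
For context on why the GenericUDFMacro branch exists: every temporary macro resolves to the same GenericUDFMacro class, so class-name equality alone would make distinct macros compare equal, while the GROUP BY fix needs two calls to the same function to compare equal. A hypothetical spark-shell session along the lines of the test above (CREATE TEMPORARY MACRO is standard HiveQL; the macro name here is made up):

    // Assumes tab1 exists, as registered in the test above.
    sqlContext.sql("CREATE TEMPORARY MACRO plus_one(x INT) x + 1")
    // Analysis succeeds only if the two plus_one expressions are
    // recognized as the same grouping expression, i.e. only if the
    // function wrappers compare equal by macro body.
    sqlContext.sql(
      "SELECT plus_one(test_date) FROM tab1 GROUP BY plus_one(test_date)").show()
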