author     Dilip Biswal <dbiswal@us.ibm.com>	2016-01-12 21:41:38 -0800
committer  Yin Huai <yhuai@databricks.com>	2016-01-12 21:41:46 -0800
commit     dc7b3870fcfc2723319dbb8c53d721211a8116be (patch)
tree       27246f4af2ee60c93a99524f145306b933eb06f6
parent     f14922cff84b1e0984ba4597d764615184126bdc (diff)
[SPARK-12558][SQL] AnalysisException when multiple functions applied in GROUP BY clause
cloud-fan Can you please take a look? In this case, we are failing during check analysis while validating the aggregation expression. I have added a semanticEquals for HiveGenericUDF to fix this. Please let me know if this is the right way to address this issue.

Author: Dilip Biswal <dbiswal@us.ibm.com>

Closes #10520 from dilipbiswal/spark-12558.
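Before this patch, repeating the same Hive UDF call in both the select list and the GROUP BY clause failed check analysis. A minimal reproduction, adapted from the test added below (a sketch assuming a Hive-enabled SQLContext with its implicits in scope, as in HiveUDFSuite, and the Spark 1.6-era API):

    // Register a one-row temp table, mirroring the new test below.
    Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1")

    // Before this fix, the analyzer raised an AnalysisException here: the
    // date(...) call in the projection did not compare as semantically equal
    // to the identical date(...) call in GROUP BY, so it looked like a
    // non-grouped, non-aggregate expression.
    sql("select date(cast(test_date as timestamp)) from tab1 " +
      "group by date(cast(test_date as timestamp))").show()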
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala               | 23
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala |  7
2 files changed, 30 insertions(+), 0 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index b8cced0b80..087b0c087c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -26,11 +26,13 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
+import com.google.common.base.Objects
import org.apache.avro.Schema
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
@@ -45,6 +47,7 @@ private[hive] object HiveShim {
// scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
val UNLIMITED_DECIMAL_PRECISION = 38
val UNLIMITED_DECIMAL_SCALE = 18
+ val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
/*
* This function in hive-0.13 became private, but we have to do this to work around a Hive bug
@@ -123,6 +126,26 @@ private[hive] object HiveShim {
// for Serialization
def this() = this(null)
+ override def hashCode(): Int = {
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
+ } else {
+ functionClassName.hashCode()
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
+ // In case of udf macro, check to make sure they point to the same underlying UDF
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
+ instance.asInstanceOf[GenericUDFMacro].getBody()
+ } else {
+ true
+ }
+ case _ => false
+ }
+
@transient
def deserializeObjectByKryo[T: ClassTag](
kryo: Kryo,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index c5ff8825ab..dfe33ba8b0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -350,6 +350,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
sqlContext.dropTempTable("testUDF")
}
+ test("Hive UDF in group by") {
+ Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1")
+ val count = sql("select date(cast(test_date as timestamp))" +
+ " from tab1 group by date(cast(test_date as timestamp))").count()
+ assert(count == 1)
+ }
+
test("SPARK-11522 select input_file_name from non-parquet table"){
withTempDir { tempDir =>