aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorXi Liu <xil@conviva.com>2014-06-17 13:14:40 +0200
committerMichael Armbrust <michael@databricks.com>2014-06-17 13:19:53 +0200
commit8e6b77fe0773d0b72d2c677bfcfa3323718c7e3c (patch)
tree116ff81ebb58d61234f9a18a29f9058b40bb2411 /sql/hive
parent3d4fa2dab0ade3b7948497a3336fd3e238f93507 (diff)
downloadspark-8e6b77fe0773d0b72d2c677bfcfa3323718c7e3c.tar.gz
spark-8e6b77fe0773d0b72d2c677bfcfa3323718c7e3c.tar.bz2
spark-8e6b77fe0773d0b72d2c677bfcfa3323718c7e3c.zip
[SPARK-2164][SQL] Allow Hive UDF on columns of type struct
Author: Xi Liu <xil@conviva.com>

Closes #796 from xiliu82/sqlbug and squashes the following commits:

328dfc4 [Xi Liu] [Spark SQL] remove a temporary function after test
354386a [Xi Liu] [Spark SQL] add test suite for UDF on struct
8fc6f51 [Xi Liu] [SparkSQL] allow UDF on struct

(cherry picked from commit f5a4049e534da3c55e1b495ce34155236dfb6dee)
Signed-off-by: Michael Armbrust <michael@databricks.com>
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala3
-rwxr-xr-xsql/hive/src/test/resources/data/files/testUdf/part-00000bin0 -> 153 bytes
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala127
3 files changed, 130 insertions, 0 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 771d2bccf4..ad5e24c62c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -335,6 +335,9 @@ private[hive] trait HiveInspectors {
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+ case StructType(fields) =>
+ ObjectInspectorFactory.getStandardStructObjectInspector(
+ fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
}
def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000
new file mode 100755
index 0000000000..240a5c1a63
--- /dev/null
+++ b/sql/hive/src/test/resources/data/files/testUdf/part-00000
Binary files differ
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
new file mode 100644
index 0000000000..a9e3f42a3a
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext._
+import java.util
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe}
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector}
+import java.util.Properties
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import scala.collection.JavaConversions._
+import java.io.{DataOutput, DataInput}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
+
+/**
+ * A test suite for Hive custom UDFs.
+ */
+class HiveUdfSuite extends HiveComparisonTest {
+
+ TestHive.hql(
+ """
+ |CREATE EXTERNAL TABLE hiveUdfTestTable (
+ | pair STRUCT<id: INT, value: INT>
+ |)
+ |PARTITIONED BY (partition STRING)
+ |ROW FORMAT SERDE '%s'
+ |STORED AS SEQUENCEFILE
+ """.stripMargin.format(classOf[PairSerDe].getName)
+ )
+
+ TestHive.hql(
+ "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'"
+ .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile)
+ )
+
+ TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName))
+
+ TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable")
+
+ TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
+}
+
+class TestPair(x: Int, y: Int) extends Writable with Serializable {
+ def this() = this(0, 0)
+ var entry: (Int, Int) = (x, y)
+
+ override def write(output: DataOutput): Unit = {
+ output.writeInt(entry._1)
+ output.writeInt(entry._2)
+ }
+
+ override def readFields(input: DataInput): Unit = {
+ val x = input.readInt()
+ val y = input.readInt()
+ entry = (x, y)
+ }
+}
+
+class PairSerDe extends AbstractSerDe {
+ override def initialize(p1: Configuration, p2: Properties): Unit = {}
+
+ override def getObjectInspector: ObjectInspector = {
+ ObjectInspectorFactory
+ .getStandardStructObjectInspector(
+ Seq("pair"),
+ Seq(ObjectInspectorFactory.getStandardStructObjectInspector(
+ Seq("id", "value"),
+ Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector))
+ ))
+ }
+
+ override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair]
+
+ override def getSerDeStats: SerDeStats = null
+
+ override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null
+
+ override def deserialize(value: Writable): AnyRef = {
+ val pair = value.asInstanceOf[TestPair]
+
+ val row = new util.ArrayList[util.ArrayList[AnyRef]]
+ row.add(new util.ArrayList[AnyRef](2))
+ row(0).add(Integer.valueOf(pair.entry._1))
+ row(0).add(Integer.valueOf(pair.entry._2))
+
+ row
+ }
+}
+
+class PairUdf extends GenericUDF {
+ override def initialize(p1: Array[ObjectInspector]): ObjectInspector =
+ ObjectInspectorFactory.getStandardStructObjectInspector(
+ Seq("id", "value"),
+ Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector)
+ )
+
+ override def evaluate(args: Array[DeferredObject]): AnyRef = {
+ println("Type = %s".format(args(0).getClass.getName))
+ Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2)
+ }
+
+ override def getDisplayString(p1: Array[String]): String = ""
+}
+
+
+