Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala  |  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala   | 67
2 files changed, 44 insertions, 27 deletions
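
For context, a minimal usage sketch of what the new test in HiveUdfSuite exercises: a Scala UDF registered through registerFunction that returns a case class, whose struct fields are then addressed with dot syntax in SQL. This is an illustration, not part of the patch; the `hiveContext` value and the `src` table are assumptions (any Spark 1.x HiveContext with the standard test table would do).

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.hive.HiveContext

    case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int)

    // Assumes `hiveContext: HiveContext` and a populated table `src` already exist.
    hiveContext.registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))

    // Individual struct fields are selected with dot syntax.
    val row = hiveContext.sql(
      "SELECT getStruct(1).f1, getStruct(1).f5 FROM src LIMIT 1").first()
    assert(row == Row(1, 5))
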
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 70c6d06cf2..49520b7678 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -308,13 +308,9 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) {
object StructType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
-
- private def validateFields(fields: Seq[StructField]): Boolean =
- fields.map(field => field.name).distinct.size == fields.size
}
case class StructType(fields: Seq[StructField]) extends DataType {
- require(StructType.validateFields(fields), "Found fields with the same name.")
/**
* Returns all field names in a [[Seq]].
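
As a brief illustration of the Catalyst change above: before this patch, StructType's constructor enforced unique field names via the require shown in the removed lines, so the snippet below would have thrown an IllegalArgumentException ("Found fields with the same name."); after the patch it constructs normally. A minimal sketch against the Catalyst types in this file; IntegerType is chosen purely for illustration.

    import org.apache.spark.sql.catalyst.types._

    val fields = Seq(
      StructField("f1", IntegerType, nullable = false),
      StructField("f1", IntegerType, nullable = false)) // duplicate name on purpose

    // Previously rejected by:
    //   require(StructType.validateFields(fields), "Found fields with the same name.")
    // With this patch the constructor no longer checks, so construction succeeds
    // and any handling of duplicate names is left to the caller.
    val struct = StructType(fields)
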
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index b6b8592344..cc125d539c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -17,47 +17,68 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark.SparkContext._
+import java.io.{DataOutput, DataInput}
import java.util
-import org.apache.hadoop.fs.{FileSystem, Path}
+import java.util.Properties
+
+import org.apache.spark.util.Utils
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe}
-import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector}
-import java.util.Properties
+
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import scala.collection.JavaConversions._
-import java.io.{DataOutput, DataInput}
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+
+case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int)
+
/**
* A test suite for Hive custom UDFs.
*/
class HiveUdfSuite extends HiveComparisonTest {
- TestHive.sql(
- """
+ test("spark sql udf test that returns a struct") {
+ registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
+ assert(sql(
+ """
+ |SELECT getStruct(1).f1,
+ | getStruct(1).f2,
+ | getStruct(1).f3,
+ | getStruct(1).f4,
+ | getStruct(1).f5 FROM src LIMIT 1
+ """.stripMargin).first() === Row(1, 2, 3, 4, 5))
+ }
+
+ test("hive struct udf") {
+ sql(
+ """
|CREATE EXTERNAL TABLE hiveUdfTestTable (
| pair STRUCT<id: INT, value: INT>
|)
|PARTITIONED BY (partition STRING)
|ROW FORMAT SERDE '%s'
|STORED AS SEQUENCEFILE
- """.stripMargin.format(classOf[PairSerDe].getName)
- )
-
- TestHive.sql(
- "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'"
- .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile)
- )
-
- TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName))
-
- TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable")
-
- TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
+ """.
+ stripMargin.format(classOf[PairSerDe].getName))
+
+ val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile
+ sql(s"""
+ ALTER TABLE hiveUdfTestTable
+ ADD IF NOT EXISTS PARTITION(partition='testUdf')
+ LOCATION '$location'""")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'")
+ sql("SELECT testUdf(pair) FROM hiveUdfTestTable")
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {