aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Baptiste Onofré <jbonofre@apache.org>2015-11-20 14:45:40 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-20 14:45:40 -0800
commit03ba56d78f50747710d01c27d409ba2be42ae557 (patch)
tree574da5cb3b1b44ed727e24b2586a78c82a5ce8f0
parent89fd9bd06160fa89dedbf685bfe159ffe4a06ec6 (diff)
downloadspark-03ba56d78f50747710d01c27d409ba2be42ae557.tar.gz
spark-03ba56d78f50747710d01c27d409ba2be42ae557.tar.bz2
spark-03ba56d78f50747710d01c27d409ba2be42ae557.zip
[SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating the UserDefinedFunction
https://issues.apache.org/jira/browse/SPARK-11716 This is one is #9739 and a regression test. When commit it, please make sure the author is jbonofre. You can find the original PR at https://github.com/apache/spark/pull/9739 closes #9739 Author: Jean-Baptiste Onofré <jbonofre@apache.org> Author: Yin Huai <yhuai@databricks.com> Closes #9868 from yhuai/SPARK-11716.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala48
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala15
2 files changed, 39 insertions, 24 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index fc4d0938c5..051694c0d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try($inputTypes).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}""")
}
@@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -315,7 +315,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
/**
@@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
- UserDefinedFunction(func, dataType)
+ UserDefinedFunction(func, dataType, inputTypes)
}
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 9837fa6bdb..fd736718af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext {
| (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp
""".stripMargin).toDF(), complexData.select("m", "a", "b"))
}
+
+ test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") {
+ val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })
+
+ // Without the fix, this will fail because we fail to cast data type of b to string
+ // because myUDF does not know its input data type. With the fix, this query should not
+ // fail.
+ checkAnswer(
+ testData2.select(myUDF($"a", $"b").as("t")),
+ testData2.selectExpr("struct(a, b)"))
+
+ checkAnswer(
+ sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(),
+ testData2)
+ }
}