author    Yin Huai <yhuai@databricks.com>    2015-09-17 11:14:52 -0700
committer Michael Armbrust <michael@databricks.com>    2015-09-17 11:14:52 -0700
commit    aad644fbe29151aec9004817d42e4928bdb326f3 (patch)
tree      77bfa902698f82e6e2547e9bc70dfec46bd0970f /sql/core
parent    e0dc2bc232206d2f4da4278502c1f88babc8b55a (diff)
[SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type
JIRA: https://issues.apache.org/jira/browse/SPARK-10639

Author: Yin Huai <yhuai@databricks.com>

Closes #8788 from yhuai/udafConversion.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala  | 37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                 | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala      | 11

3 files changed, 60 insertions(+), 9 deletions(-)
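
The substance of the change: ScalaUDAF.eval previously returned udaf.evaluate(...) as-is, so a UDAF whose declared dataType needs conversion (array, struct, UDT, date, timestamp) handed a raw Scala object to the execution engine instead of its Catalyst representation. A minimal sketch of a UDAF that hits this path, written against the public Spark 1.5 API — the class name and its toy logic are illustrative, not part of the patch:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Collects integers into an array. evaluate() returns a Scala Seq, which
// must be converted to Catalyst's internal array format before it can be
// stored in an InternalRow -- the conversion this patch adds.
class CollectInts extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", IntegerType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("items", ArrayType(IntegerType)) :: Nil)
  def dataType: DataType = ArrayType(IntegerType)
  def deterministic: Boolean = true
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, Seq.empty[Int])
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer.update(0, buffer.getSeq[Int](0) :+ input.getInt(0))
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[Int](0) ++ buffer2.getSeq[Int](0))
  }
  def evaluate(buffer: Row): Any = buffer.getSeq[Int](0) // a Scala Seq, not a Catalyst value
}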
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index d43d3dd9ff..1114fe6552 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < getters.length) {
getters(i) = dataTypes(i) match {
+ case NullType =>
+ (row: InternalRow, ordinal: Int) => null
+
case BooleanType =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
@@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils {
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
+ case DateType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+ case TimestampType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
case other =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
@@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < setters.length) {
setters(i) = dataTypes(i) match {
+ case NullType =>
+ (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal)
+
case b: BooleanType =>
(row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
@@ -151,8 +165,22 @@ sealed trait BufferSetterGetterUtils {
case dt: DecimalType =>
val precision = dt.precision
(row: MutableRow, ordinal: Int, value: Any) =>
+ // To make it work with UnsafeRow, we cannot use setNullAt.
+ // Please see the comment of UnsafeRow's setDecimal.
+ row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+
+ case DateType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
- row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+ row.setInt(ordinal, value.asInstanceOf[Int])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case TimestampType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setLong(ordinal, value.asInstanceOf[Long])
} else {
row.setNullAt(ordinal)
}
@@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
+
toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
}
@@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF(
}
}
+ private[this] lazy val outputToCatalystConverter: Any => Any = {
+ CatalystTypeConverters.createToCatalystConverter(dataType)
+ }
+
// This buffer is only used at executor side.
private[this] var inputAggregateBuffer: InputAggregationBuffer = null
@@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF(
override def eval(buffer: InternalRow): Any = {
evalAggregateBuffer.underlyingInputBuffer = buffer
- udaf.evaluate(evalAggregateBuffer)
+ outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))
}
override def toString: String = {
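
Two details worth noting about the udaf.scala changes above. First, Catalyst stores DateType as days since the epoch (an Int) and TimestampType as microseconds since the epoch (a Long), which is why the new getter and setter cases use getInt/setInt and getLong/setLong rather than falling through to the generic `other` case. Second, outputToCatalystConverter is just the standard converter for the UDAF's declared dataType, now applied to evaluate()'s result. A rough sketch of both, assuming the code lives under org.apache.spark.sql since CatalystTypeConverters is private[sql] in this era of Spark:

import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{DateType, TimestampType}

// DateType values travel through aggregation buffers as Ints (days since epoch).
val dateConverter = CatalystTypeConverters.createToCatalystConverter(DateType)
val days = dateConverter(Date.valueOf("2015-09-17")) // an Int

// TimestampType values travel as Longs (microseconds since epoch).
val tsConverter = CatalystTypeConverters.createToCatalystConverter(TimestampType)
val micros = tsConverter(Timestamp.valueOf("2015-09-17 11:14:52")) // a Long

// eval() now routes evaluate()'s result through the same machinery:
//   outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))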
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index cada03e9ac..e3c5a42667 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -115,19 +115,26 @@ object QueryTest {
*/
def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+
+ // We need to call prepareRow recursively to handle schemas with struct types.
+ def prepareRow(row: Row): Row = {
+ Row.fromSeq(row.toSeq.map {
+ case null => null
+ case d: java.math.BigDecimal => BigDecimal(d)
+ // Convert array to Seq for easy equality check.
+ case b: Array[_] => b.toSeq
+ case r: Row => prepareRow(r)
+ case o => o
+ })
+ }
+
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
- val converted: Seq[Row] = answer.map { s =>
- Row.fromSeq(s.toSeq.map {
- case d: java.math.BigDecimal => BigDecimal(d)
- case b: Array[Byte] => b.toSeq
- case o => o
- })
- }
+ val converted: Seq[Row] = answer.map(prepareRow)
if (!isSorted) converted.sortBy(_.toString()) else converted
}
val sparkAnswer = try df.collect().toSeq catch {
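
Two properties of the rewritten comparison are worth spelling out. Scala's BigDecimal equality uses compareTo semantics while java.math.BigDecimal's is scale-sensitive, and Scala Seq equality is element-wise while Java array equality is reference-based; prepareRow normalizes both, and the recursion extends that normalization into nested struct columns, which the old top-level-only conversion missed. A small illustration, assuming a local copy of prepareRow as defined above:

import org.apache.spark.sql.Row

// Scale-sensitive vs. compareTo-style equality:
assert(new java.math.BigDecimal("1.0") != new java.math.BigDecimal("1.00"))
assert(BigDecimal("1.0") == BigDecimal("1.00"))

// Reference-based vs. element-wise equality:
assert(Array[Byte](1, 2) != Array[Byte](1, 2))
assert(Array[Byte](1, 2).toSeq == Array[Byte](1, 2).toSeq)

// prepareRow applies these conversions inside nested rows as well:
val nested = Row(Row(new java.math.BigDecimal("1.0"), Array[Byte](1, 2)))
// prepareRow(nested) == Row(Row(BigDecimal("1.0"), Seq[Byte](1, 2)))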
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 46d87843df..7992fd59ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
assert(new MyDenseVectorUDT().typeName === "mydensevector")
assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
}
+
+ test("Catalyst type converter null handling for UDTs") {
+ val udt = new MyDenseVectorUDT()
+ val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
+ assert(toScalaConverter(null) === null)
+
+ val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt)
+ assert(toCatalystConverter(null) === null)
+
+ }
}
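
The new test pins down an easy-to-miss contract: for a UDT, createToCatalystConverter and createToScalaConverter must pass null through rather than invoking the UDT's serialize/deserialize on it, which is what the null-checking branches in the buffer getters and setters rely on. A sketch of the round-trip behavior for non-null values under the same assumptions (test-scoped code in org.apache.spark.sql, relying on MyDenseVector's element-wise equals from this suite):

val udt = new MyDenseVectorUDT()
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(udt)
val toScala = CatalystTypeConverters.createToScalaConverter(udt)

// Non-null values are serialized to the UDT's sqlType and back.
val vector = new MyDenseVector(Array(1.0, 2.0))
assert(toScala(toCatalyst(vector)) === vector)

// Nulls short-circuit both directions, as the new test asserts.
assert(toCatalyst(null) === null && toScala(null) === null)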