Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala    7
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala                16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala               37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                              21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                   11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala  108
6 files changed, 188 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 966623ed01..f25591794a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -138,8 +138,13 @@ object CatalystTypeConverters {
private case class UDTConverter(
udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+ // toCatalyst (which calls toCatalystImpl) already does the null check.
override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
- override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
+
+ override def toScala(catalystValue: Any): Any = {
+ if (catalystValue == null) null else udt.deserialize(catalystValue)
+ }
+
override def toScalaImpl(row: InternalRow, column: Int): Any =
toScala(row.get(column, udt.sqlType))
}
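As a minimal illustration of what this change buys (using MyDenseVectorUDT from the test suite below): before the change, toScala(null) reached udt.deserialize(null), which typical deserialize implementations, those that pattern-match on the serialized form, do not handle.

    import org.apache.spark.sql.catalyst.CatalystTypeConverters

    // Converters built for a UDT now short-circuit on null in both directions
    // instead of handing null to udt.serialize / udt.deserialize.
    val udt = new MyDenseVectorUDT()
    val toScala = CatalystTypeConverters.createToScalaConverter(udt)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(udt)
    assert(toScala(null) == null)
    assert(toCatalyst(null) == null)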
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 4025cbcec1..e48395028e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -108,7 +108,21 @@ object RandomDataGenerator {
arr
})
case BooleanType => Some(() => rand.nextBoolean())
- case DateType => Some(() => new java.sql.Date(rand.nextInt()))
+ case DateType =>
+ val generator =
+ () => {
+ var milliseconds = rand.nextLong() % 253402329599999L
+ // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
+ // for "0001-01-01 00:00:00.000000". We need to find a number that is greater than or
+ // equal to this number to get a valid timestamp value.
+ while (milliseconds < -62135740800000L) {
+ // 253402329599999L is the number of milliseconds since
+ // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
+ milliseconds = rand.nextLong() % 253402329599999L
+ }
+ DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt)
+ }
+ Some(generator)
case TimestampType =>
val generator =
() => {
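The new DateType case can be exercised directly through RandomDataGenerator.forType; a quick sketch (values are random, so only the runtime type is checked):

    import org.apache.spark.sql.RandomDataGenerator
    import org.apache.spark.sql.types.DateType

    // Ask for a non-nullable DateType generator and draw a value from it.
    val generate = RandomDataGenerator
      .forType(dataType = DateType, nullable = false, seed = Some(42L))
      .getOrElse(sys.error("no generator for DateType"))
    assert(generate().isInstanceOf[java.sql.Date])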
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index d43d3dd9ff..1114fe6552 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < getters.length) {
getters(i) = dataTypes(i) match {
+ case NullType =>
+ (row: InternalRow, ordinal: Int) => null
+
case BooleanType =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
@@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils {
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
+ case DateType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+ case TimestampType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
case other =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
@@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < setters.length) {
setters(i) = dataTypes(i) match {
+ case NullType =>
+ (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal)
+
case b: BooleanType =>
(row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
@@ -151,8 +165,22 @@ sealed trait BufferSetterGetterUtils {
case dt: DecimalType =>
val precision = dt.precision
(row: MutableRow, ordinal: Int, value: Any) =>
+ // To make it work with UnsafeRow, we cannot use setNullAt here.
+ // Please see the comment on UnsafeRow's setDecimal.
+ row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+
+ case DateType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
- row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+ row.setInt(ordinal, value.asInstanceOf[Int])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case TimestampType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setLong(ordinal, value.asInstanceOf[Long])
} else {
row.setNullAt(ordinal)
}
@@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
+
toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
}
@@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF(
}
}
+ private[this] lazy val outputToCatalystConverter: Any => Any = {
+ CatalystTypeConverters.createToCatalystConverter(dataType)
+ }
+
// This buffer is only used at executor side.
private[this] var inputAggregateBuffer: InputAggregationBuffer = null
@@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF(
override def eval(buffer: InternalRow): Any = {
evalAggregateBuffer.underlyingInputBuffer = buffer
- udaf.evaluate(evalAggregateBuffer)
+ outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))
}
override def toString: String = {
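DateType and TimestampType get their own branches because Catalyst stores them physically as an Int (days since the epoch) and a Long (microseconds since the epoch). A sketch of the getter pattern the new branches follow, mirroring the code above:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.{DataType, DateType, TimestampType}

    // Dates and timestamps are read through their primitive physical types
    // rather than the generic row.get(ordinal, dataType) fallback.
    def getterFor(dataType: DataType): (InternalRow, Int) => Any = dataType match {
      case DateType =>
        (row: InternalRow, ordinal: Int) =>
          if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
      case TimestampType =>
        (row: InternalRow, ordinal: Int) =>
          if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
      case other =>
        (row: InternalRow, ordinal: Int) =>
          if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
    }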
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index cada03e9ac..e3c5a42667 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -115,19 +115,26 @@ object QueryTest {
*/
def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+
+ // We need to call prepareRow recursively to handle schemas with struct types.
+ def prepareRow(row: Row): Row = {
+ Row.fromSeq(row.toSeq.map {
+ case null => null
+ case d: java.math.BigDecimal => BigDecimal(d)
+ // Convert array to Seq for easy equality check.
+ case b: Array[_] => b.toSeq
+ case r: Row => prepareRow(r)
+ case o => o
+ })
+ }
+
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for the
// equality test.
- val converted: Seq[Row] = answer.map { s =>
- Row.fromSeq(s.toSeq.map {
- case d: java.math.BigDecimal => BigDecimal(d)
- case b: Array[Byte] => b.toSeq
- case o => o
- })
- }
+ val converted: Seq[Row] = answer.map(prepareRow)
if (!isSorted) converted.sortBy(_.toString()) else converted
}
val sparkAnswer = try df.collect().toSeq catch {
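A self-contained sketch of what the recursive prepareRow enables (the helper is duplicated here only for illustration; the asserts assume Row equality compares elements with ==, which checkAnswer already relies on):

    import org.apache.spark.sql.Row

    def prepareRow(row: Row): Row = Row.fromSeq(row.toSeq.map {
      case null => null
      case d: java.math.BigDecimal => BigDecimal(d)
      case b: Array[_] => b.toSeq
      case r: Row => prepareRow(r)
      case o => o
    })

    // Nested structs holding java.math.BigDecimal or arrays now compare
    // structurally: a scale difference and a distinct-but-equal array defeat
    // the raw comparison but not the normalized one.
    val a = Row(Row(new java.math.BigDecimal("1.0"), Array[Byte](1, 2)))
    val b = Row(Row(new java.math.BigDecimal("1.00"), Array[Byte](1, 2)))
    assert(a != b)
    assert(prepareRow(a) == prepareRow(b))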
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 46d87843df..7992fd59ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
assert(new MyDenseVectorUDT().typeName === "mydensevector")
assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
}
+
+ test("Catalyst type converter null handling for UDTs") {
+ val udt = new MyDenseVectorUDT()
+ val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
+ assert(toScalaConverter(null) === null)
+
+ val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt)
+ assert(toCatalystConverter(null) === null)
+ }
}
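An end-to-end sketch of the behavior the new test guards (assuming the sqlContext available in this suite): a nullable UDT column should survive a DataFrame round trip with nulls intact instead of failing inside deserialize.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StructField, StructType}

    // A single UDT column with one null value.
    val schema = StructType(StructField("vec", new MyDenseVectorUDT(), nullable = true) :: Nil)
    val rows = Seq(Row(new MyDenseVector(Array(1.0, 2.0))), Row(null))
    val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(rows), schema)
    // collect() goes through the to-Scala converter; the null row stays null.
    assert(df.collect().map(_.isNullAt(0)).toSeq == Seq(false, true))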
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a73b1bd52c..24b1846923 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,13 +17,55 @@
package org.apache.spark.sql.hive.execution
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
import org.apache.spark.sql.hive.test.TestHiveSingleton
+class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
+
+ def inputSchema: StructType = schema
+
+ def bufferSchema: StructType = schema
+
+ def dataType: DataType = schema
+
+ def deterministic: Boolean = true
+
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ (0 until schema.length).foreach { i =>
+ buffer.update(i, null)
+ }
+ }
+
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ if (!input.isNullAt(0) && input.getInt(0) == 50) {
+ (0 until schema.length).foreach { i =>
+ buffer.update(i, input.get(i))
+ }
+ }
+ }
+
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
+ (0 until schema.length).foreach { i =>
+ buffer1.update(i, buffer2.get(i))
+ }
+ }
+ }
+
+ def evaluate(buffer: Row): Any = {
+ Row.fromSeq(buffer.toSeq)
+ }
+}
+
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
}
}
+
+ test("udaf with all data types") {
+ val struct =
+ StructType(
+ StructField("f1", FloatType, true) ::
+ StructField("f2", ArrayType(BooleanType), true) :: Nil)
+ val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
+ ByteType, ShortType, IntegerType, LongType,
+ FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+ DateType, TimestampType,
+ ArrayType(IntegerType), MapType(StringType, LongType), struct,
+ new MyDenseVectorUDT())
+ // Right now, we use SortBasedAggregate to handle UDAFs.
+ // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
+ // UnsafeRow as the aggregation buffer, while dataTypes will trigger
+ // SortBasedAggregate to use a safe row as the aggregation buffer.
+ Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
+ val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+ StructField(s"col$index", dataType, nullable = true)
+ }
+ // The schema used for data generator.
+ val schemaForGenerator = StructType(fields)
+ // The schema used for the DataFrame df.
+ val schema = StructType(StructField("id", IntegerType) +: fields)
+
+ logInfo(s"Testing schema: ${schema.treeString}")
+
+ val udaf = new ScalaAggregateFunction(schema)
+ // Generate data at the driver side. We need to materialize the data first and then
+ // create RDD.
+ val maybeDataGenerator =
+ RandomDataGenerator.forType(
+ dataType = schemaForGenerator,
+ nullable = true,
+ seed = Some(System.nanoTime()))
+ val dataGenerator =
+ maybeDataGenerator
+ .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
+ val data = (1 to 50).map { i =>
+ dataGenerator.apply() match {
+ case row: Row => Row.fromSeq(i +: row.toSeq)
+ case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+ case other =>
+ fail(s"Row or null is expected to be generated, " +
+ s"but a ${other.getClass.getCanonicalName} is generated.")
+ }
+ }
+
+ // Create a DF for the schema with random data.
+ val rdd = sqlContext.sparkContext.parallelize(data, 1)
+ val df = sqlContext.createDataFrame(rdd, schema)
+
+ val allColumns = df.schema.fields.map(f => col(f.name))
+ val expectedAnswer =
+ data
+ .find(r => r.getInt(0) == 50)
+ .getOrElse(fail("A row with id 50 should be the expected answer."))
+ checkAnswer(
+ df.groupBy().agg(udaf(allColumns: _*)),
+ // udaf returns a Row as the output value.
+ Row(expectedAnswer)
+ )
+ }
+ }
}
class SortBasedAggregationQuerySuite extends AggregationQuerySuite {
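For context on the eval() change in udaf.scala: a UDAF like ScalaAggregateFunction above declares a StructType dataType, so its evaluate() returns an external Row, and ScalaUDAF now converts that result back to Catalyst's internal representation before handing it to downstream operators. A minimal sketch of that conversion in isolation:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // The same kind of converter ScalaUDAF now builds lazily from its dataType.
    val dataType = StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
    val external = Row(1, "x")            // what a struct-returning evaluate() produces
    val internal = toCatalyst(external)   // an InternalRow with UTF8String for the string column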