Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala                          6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala   18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala          222
3 files changed, 125 insertions(+), 121 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index ea1fd23d0d..11e0c120f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -24,6 +24,7 @@ import java.math.MathContext
import scala.util.Random
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
@@ -106,6 +107,11 @@ object RandomDataGenerator {
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
+ case CalendarIntervalType => Some(() => {
+ val months = rand.nextInt(1000)
+ val ns = rand.nextLong()
+ new CalendarInterval(months, ns)
+ })
case DecimalType.Fixed(precision, scale) => Some(
() => BigDecimal.apply(
rand.nextLong() % math.pow(10, precision).toLong,
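The new case above plugs CalendarIntervalType into the generator's type match. A minimal usage sketch, assuming forType(dataType, nullable) returns Option[() => Any] as it is used later in this patch:

    import org.apache.spark.sql.RandomDataGenerator
    import org.apache.spark.sql.types.CalendarIntervalType
    import org.apache.spark.unsafe.types.CalendarInterval

    // Draw one random interval via the generator case added above.
    val gen = RandomDataGenerator.forType(CalendarIntervalType, nullable = false).get
    val interval = gen().asInstanceOf[CalendarInterval]  // random months plus a random time component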
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 4c94b3307d..7c591f6143 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -56,11 +56,21 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
assert(leakedShuffleMemory === 0)
taskMemoryManager = null
}
+ TaskContext.unset()
}
test(name) {
taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
shuffleMemoryManager = new TestShuffleMemoryManager
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = Random.nextInt(10000),
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ metricsSystem = null))
+
try {
f
} catch {
@@ -163,14 +173,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
- TaskContext.setTaskContext(new TaskContextImpl(
- stageId = 0,
- partitionId = 0,
- taskAttemptId = 0,
- attemptNumber = 0,
- taskMemoryManager = taskMemoryManager,
- metricsSystem = null))
-
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
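The net effect of this hunk is that TaskContext setup moves out of individual tests into the shared harness, with teardown guaranteed to unset it. A condensed sketch of the pattern, using a hypothetical helper name (withTaskContext does not appear in the patch):

    import scala.util.Random
    import org.apache.spark.{TaskContext, TaskContextImpl}
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    // Hypothetical helper: run `body` with a fresh TaskContext installed,
    // and always unset it afterwards so no state leaks between tests.
    def withTaskContext(body: => Unit): Unit = {
      val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
      TaskContext.setTaskContext(new TaskContextImpl(
        stageId = 0,
        partitionId = 0,
        taskAttemptId = Random.nextInt(10000),  // random id, mirroring the patch
        attemptNumber = 0,
        taskMemoryManager = taskMemoryManager,
        metricsSystem = null))
      try body finally TaskContext.unset()
    }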
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 5d214d7bfc..0282b25b9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -19,140 +19,136 @@ package org.apache.spark.sql.execution
import scala.util.Random
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection}
+import org.apache.spark._
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark._
+/**
+ * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
+ */
class UnsafeKVExternalSorterSuite extends SparkFunSuite {
- test("sorting string key and int int value") {
-
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
+ private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
+ private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
- val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- val shuffleMemMgr = new TestShuffleMemoryManager
+ testKVSorter(new StructType, new StructType, spill = true)
+ testKVSorter(new StructType().add("c1", IntegerType), new StructType, spill = true)
+ testKVSorter(new StructType, new StructType().add("c1", IntegerType), spill = true)
- TaskContext.setTaskContext(new TaskContextImpl(
- stageId = 0,
- partitionId = 0,
- taskAttemptId = 0,
- attemptNumber = 0,
- taskMemoryManager = taskMemMgr,
- metricsSystem = null))
-
- val keySchema = new StructType().add("a", StringType)
- val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
- val sorter = new UnsafeKVExternalSorter(
- keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
- 16 * 1024)
-
- val keyConverter = UnsafeProjection.create(keySchema)
- val valueConverter = UnsafeProjection.create(valueSchema)
+ private val rand = new Random(42)
+ for (i <- 0 until 6) {
+ val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes)
+ val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes)
+ testKVSorter(keySchema, valueSchema, spill = i > 3)
+ }
- val rand = new Random(42)
- val data = null +: Seq.fill[String](10) {
- Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
- }
+ /**
+ * Create a test case using randomly generated data for the given key and value schema.
+ *
+ * The approach works as follows:
+ *
+ * - Create input by randomly generating data based on the given schema
+ * - Run [[UnsafeKVExternalSorter]] on the generated data
+ * - Collect the output from the sorter, and make sure the keys are sorted in ascending order
+ * - Sort the input by both key and value, and sort the sorter output the same way. Compare the
+ *   sorted input and sorted output to make sure all key/value pairs match.
+ *
+ * If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
+ */
+ private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = {
+
+ val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]")
+ val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]")
+
+ test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") {
+ // Calling this makes sure we have the block manager and everything else set up.
+ TestSQLContext
+
+ val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val shuffleMemMgr = new TestShuffleMemoryManager
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 98456,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemMgr,
+ metricsSystem = null))
+
+ // Create the data converters
+ val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
+ val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
+ val kConverter = UnsafeProjection.create(keySchema)
+ val vConverter = UnsafeProjection.create(valueSchema)
+
+ val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
+ val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
+
+ val input = Seq.fill(1024) {
+ val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
+ val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
+ (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+ }
- val inputRows = data.map { str =>
- keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy()
- }
+ val sorter = new UnsafeKVExternalSorter(
+ keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024)
- var i = 0
- data.foreach { str =>
- if (str != null) {
- val k = InternalRow(UTF8String.fromString(str))
- val v = InternalRow(str.length, str.length + 1)
- sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
- } else {
- val k = InternalRow(UTF8String.fromString(str))
- val v = InternalRow(-1, -2)
- sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+ // Insert generated keys and values into the sorter
+ input.foreach { case (k, v) =>
+ sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
+ // 1% chance we will spill
+ if (rand.nextDouble() < 0.01 && spill) {
+ shuffleMemMgr.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
}
- if ((i % 100) == 0) {
- shuffleMemMgr.markAsOutOfMemory()
- sorter.closeCurrentPage()
+ // Collect the sorted output
+ val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+ out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
}
- i += 1
- }
- val out = new scala.collection.mutable.ArrayBuffer[InternalRow]
- val iter = sorter.sortedIterator()
- while (iter.next()) {
- if (iter.getKey.getUTF8String(0) == null) {
- withClue(s"for null key") {
- assert(-1 === iter.getValue.getInt(0))
- assert(-2 === iter.getValue.getInt(1))
- }
- } else {
- val key = iter.getKey.getString(0)
- withClue(s"for key $key") {
- assert(key.length === iter.getValue.getInt(0))
- assert(key.length + 1 === iter.getValue.getInt(1))
+ val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
+ val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
+ val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
+ override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
+ keyOrdering.compare(x._1, y._1) match {
+ case 0 => valueOrdering.compare(x._2, y._2)
+ case cmp => cmp
+ }
}
}
- out += iter.getKey.copy()
- }
- assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType))))
- }
-
- test("sorting arbitrary string data") {
-
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
-
- val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- val shuffleMemMgr = new TestShuffleMemoryManager
-
- TaskContext.setTaskContext(new TaskContextImpl(
- stageId = 0,
- partitionId = 0,
- taskAttemptId = 0,
- attemptNumber = 0,
- taskMemoryManager = taskMemMgr,
- metricsSystem = null))
-
- val keySchema = new StructType().add("a", StringType)
- val valueSchema = new StructType().add("b", IntegerType)
- val sorter = new UnsafeKVExternalSorter(
- keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
- 16 * 1024)
-
- val keyConverter = UnsafeProjection.create(keySchema)
- val valueConverter = UnsafeProjection.create(valueSchema)
-
- val rand = new Random(42)
- val data = Seq.fill(512) {
- Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
- }
+ // Testing to make sure output from the sorter is sorted by key
+ var prevK: InternalRow = null
+ out.zipWithIndex.foreach { case ((k, v), i) =>
+ if (prevK != null) {
+ assert(keyOrdering.compare(prevK, k) <= 0,
+ s"""
+ |key is not in sorted order:
+ |previous key: $prevK
+ |current key : $k
+ """.stripMargin)
+ }
+ prevK = k
+ }
- var i = 0
- data.foreach { str =>
- val k = InternalRow(UTF8String.fromString(str))
- val v = InternalRow(str.length)
- sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+ // Testing to make sure the key/value in output matches input
+ assert(out.sorted(kvOrdering) === input.sorted(kvOrdering))
- if ((i % 100) == 0) {
- shuffleMemMgr.markAsOutOfMemory()
- sorter.closeCurrentPage()
+ // Make sure there is no memory leak
+ val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
+ if (shuffleMemMgr != null) {
+ val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
+ assert(0L === leakedShuffleMemory)
}
- i += 1
+ assert(0 === leakedUnsafeMemory)
+ TaskContext.unset()
}
-
- val out = new scala.collection.mutable.ArrayBuffer[String]
- val iter = sorter.sortedIterator()
- while (iter.next()) {
- assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
- out += iter.getKey.getString(0)
- }
-
- assert(out === data.sorted)
}
}
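The comparison trick at the heart of the new testKVSorter is ordinary lexicographic composition of two orderings over (key, value) pairs. A self-contained sketch with Int keys and values standing in for InternalRow and RowOrdering:

    // Compose a pair ordering from a key ordering and a value ordering,
    // then verify "sorted output == sorted input" the way the suite does.
    val keyOrd = Ordering.Int
    val valueOrd = Ordering.Int
    val kvOrd = new Ordering[(Int, Int)] {
      override def compare(x: (Int, Int), y: (Int, Int)): Int =
        keyOrd.compare(x._1, y._1) match {
          case 0 => valueOrd.compare(x._2, y._2)
          case cmp => cmp
        }
    }
    val input = Seq((2, 9), (1, 5), (2, 3))
    assert(input.sorted(kvOrd) == Seq((1, 5), (2, 3), (2, 9)))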