diff options
Diffstat (limited to 'sql')
3 files changed, 125 insertions, 121 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index ea1fd23d0d..11e0c120f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -24,6 +24,7 @@ import java.math.MathContext import scala.util.Random import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval /** * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random @@ -106,6 +107,11 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case CalendarIntervalType => Some(() => { + val months = rand.nextInt(1000) + val ns = rand.nextLong() + new CalendarInterval(months, ns) + }) case DecimalType.Fixed(precision, scale) => Some( () => BigDecimal.apply( rand.nextLong() % math.pow(10, precision).toLong, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 4c94b3307d..7c591f6143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -56,11 +56,21 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { assert(leakedShuffleMemory === 0) taskMemoryManager = null } + TaskContext.unset() } test(name) { taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) shuffleMemoryManager = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = null)) + try { f } catch { @@ -163,14 +173,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { // Calling this make sure we have block manager and everything else setup. TestSQLContext - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - partitionId = 0, - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - metricsSystem = null)) - // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 5d214d7bfc..0282b25b9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -19,140 +19,136 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection} +import org.apache.spark._ +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark._ +/** + * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. + */ class UnsafeKVExternalSorterSuite extends SparkFunSuite { - test("sorting string key and int int value") { - - // Calling this make sure we have block manager and everything else setup. - TestSQLContext + private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) + private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val shuffleMemMgr = new TestShuffleMemoryManager + testKVSorter(new StructType, new StructType, spill = true) + testKVSorter(new StructType().add("c1", IntegerType), new StructType, spill = true) + testKVSorter(new StructType, new StructType().add("c1", IntegerType), spill = true) - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - partitionId = 0, - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemMgr, - metricsSystem = null)) - - val keySchema = new StructType().add("a", StringType) - val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType) - val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, - 16 * 1024) - - val keyConverter = UnsafeProjection.create(keySchema) - val valueConverter = UnsafeProjection.create(valueSchema) + private val rand = new Random(42) + for (i <- 0 until 6) { + val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + testKVSorter(keySchema, valueSchema, spill = i > 3) + } - val rand = new Random(42) - val data = null +: Seq.fill[String](10) { - Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString - } + /** + * Create a test case using randomly generated data for the given key and value schema. + * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the generated data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. + */ + private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = { + + val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]") + val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]") + + test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 98456, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null)) + + // Create the data converters + val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + val kConverter = UnsafeProjection.create(keySchema) + val vConverter = UnsafeProjection.create(valueSchema) + + val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get + val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get + + val input = Seq.fill(1024) { + val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow]) + val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } - val inputRows = data.map { str => - keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy() - } + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024) - var i = 0 - data.foreach { str => - if (str != null) { - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length, str.length + 1) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) - } else { - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(-1, -2) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + // Insert generated keys and values into the sorter + input.foreach { case (k, v) => + sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) + // 1% chance we will spill + if (rand.nextDouble() < 0.01 && spill) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } } - if ((i % 100) == 0) { - shuffleMemMgr.markAsOutOfMemory() - sorter.closeCurrentPage() + // Collect the sorted output + val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] + val iter = sorter.sortedIterator() + while (iter.next()) { + out += Tuple2(iter.getKey.copy(), iter.getValue.copy()) } - i += 1 - } - val out = new scala.collection.mutable.ArrayBuffer[InternalRow] - val iter = sorter.sortedIterator() - while (iter.next()) { - if (iter.getKey.getUTF8String(0) == null) { - withClue(s"for null key") { - assert(-1 === iter.getValue.getInt(0)) - assert(-2 === iter.getValue.getInt(1)) - } - } else { - val key = iter.getKey.getString(0) - withClue(s"for key $key") { - assert(key.length === iter.getValue.getInt(0)) - assert(key.length + 1 === iter.getValue.getInt(1)) + val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType)) + val kvOrdering = new Ordering[(InternalRow, InternalRow)] { + override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { + keyOrdering.compare(x._1, y._1) match { + case 0 => valueOrdering.compare(x._2, y._2) + case cmp => cmp + } } } - out += iter.getKey.copy() - } - assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType)))) - } - - test("sorting arbitrary string data") { - - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val shuffleMemMgr = new TestShuffleMemoryManager - - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - partitionId = 0, - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemMgr, - metricsSystem = null)) - - val keySchema = new StructType().add("a", StringType) - val valueSchema = new StructType().add("b", IntegerType) - val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, - 16 * 1024) - - val keyConverter = UnsafeProjection.create(keySchema) - val valueConverter = UnsafeProjection.create(valueSchema) - - val rand = new Random(42) - val data = Seq.fill(512) { - Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString - } + // Testing to make sure output from the sorter is sorted by key + var prevK: InternalRow = null + out.zipWithIndex.foreach { case ((k, v), i) => + if (prevK != null) { + assert(keyOrdering.compare(prevK, k) <= 0, + s""" + |key is not in sorted order: + |previous key: $prevK + |current key : $k + """.stripMargin) + } + prevK = k + } - var i = 0 - data.foreach { str => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + // Testing to make sure the key/value in output matches input + assert(out.sorted(kvOrdering) === input.sorted(kvOrdering)) - if ((i % 100) == 0) { - shuffleMemMgr.markAsOutOfMemory() - sorter.closeCurrentPage() + // Make sure there is no memory leak + val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory + if (shuffleMemMgr != null) { + val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() + assert(0L === leakedShuffleMemory) } - i += 1 + assert(0 === leakedUnsafeMemory) + TaskContext.unset() } - - val out = new scala.collection.mutable.ArrayBuffer[String] - val iter = sorter.sortedIterator() - while (iter.next()) { - assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) - out += iter.getKey.getString(0) - } - - assert(out === data.sorted) } } |