-rw-r--r--  core/pom.xml | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 19
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 16
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 6
-rw-r--r--  pom.xml | 2
-rw-r--r--  project/SparkBuild.scala | 7
-rw-r--r--  sql/catalyst/pom.xml | 5
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java | 259
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 435
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala | 223
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala | 119
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 153
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 60
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2
-rw-r--r--  unsafe/pom.xml | 69
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java | 87
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java | 56
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java | 78
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java | 105
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java | 129
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 96
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 549
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java | 39
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java | 58
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 35
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java | 33
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java | 63
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java | 54
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java | 237
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java | 39
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java | 38
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java | 82
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java | 119
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 250
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java | 29
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java | 29
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java | 41
47 files changed, 3675 insertions(+), 18 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 459ef66712..2dfb00d7ec 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -96,6 +96,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 959aefabd8..0c4d28f786 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{RpcUtils, Utils}
/**
@@ -69,6 +70,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
@@ -382,6 +384,15 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
+ val executorMemoryManager: ExecutorMemoryManager = {
+ val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
+ MemoryAllocator.UNSAFE
+ } else {
+ MemoryAllocator.HEAP
+ }
+ new ExecutorMemoryManager(allocator)
+ }
+
val envInstance = new SparkEnv(
executorId,
rpcEnv,
@@ -398,6 +409,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ executorMemoryManager,
outputCommitCoordinator,
conf)
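
(Annotation, not part of the patch.) A minimal sketch of how the pieces added above fit together: SparkEnv picks the allocator from `spark.unsafe.offHeap`, each task attempt gets its own TaskMemoryManager, and any allocations a task leaks are reclaimed (and optionally reported) when it finishes. The allocate call is an assumption about TaskMemoryManager's API; only its constructor and cleanUpAllAllocatedMemory() appear in this diff.

    import org.apache.spark.SparkConf
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    val conf = new SparkConf()
    // Same selection logic as the SparkEnv change above: off-heap only when explicitly enabled.
    val allocator =
      if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryAllocator.UNSAFE
      else MemoryAllocator.HEAP
    val executorMemoryManager = new ExecutorMemoryManager(allocator)

    // One TaskMemoryManager per task attempt (see Executor.run below).
    val taskMemoryManager = new TaskMemoryManager(executorMemoryManager)
    try {
      // ... task body allocates pages here, e.g. taskMemoryManager.allocate(...) (assumed API) ...
    } finally {
      // Mirrors the leak check added to Executor.run and DAGScheduler.runLocallyWithinThread.
      val leaked = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (leaked > 0) sys.error(s"Managed memory leak detected; size = $leaked bytes")
    }
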
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 7d7fe1a446..d09e17dea0 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
/** ::DeveloperApi:: */
@DeveloperApi
def taskMetrics(): TaskMetrics
+
+ /**
+ * Returns the manager for this task's managed memory.
+ */
+ private[spark] def taskMemoryManager(): TaskMemoryManager
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 337c8e4ebe..b4d572cb52 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
+ override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index f57e215c3f..dd1c48e6cb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
/**
@@ -178,6 +179,7 @@ private[spark] class Executor(
}
override def run(): Unit = {
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
@@ -190,6 +192,7 @@ private[spark] class Executor(
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
@@ -206,7 +209,21 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ val value = try {
+ task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ } finally {
+ // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
+ // when changing this, make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(errMsg)
+ } else {
+ logError(errMsg)
+ }
+ }
+ }
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 8c4bff4e83..b7901c06a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -643,8 +644,15 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
- attemptNumber = 0, runningLocally = true)
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+ val taskContext =
+ new TaskContextImpl(
+ job.finalStage.id,
+ job.partitions(0),
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = true)
TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -652,6 +660,16 @@ class DAGScheduler(
} finally {
taskContext.markTaskCompleted()
TaskContext.unset()
+ // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
+ // make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+ } else {
+ logError(s"Managed memory leak detected; size = $freedMemory bytes")
+ }
+ }
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index b09b19e2ac..586d1e0620 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* @return the result of the task
*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
- context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
- taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+ context = new TaskContextImpl(
+ stageId = stageId,
+ partitionId = partitionId,
+ taskAttemptId = taskAttemptId,
+ attemptNumber = attemptNumber,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
}
}
+ private var taskMemoryManager: TaskMemoryManager = _
+
+ def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+ this.taskMemoryManager = taskMemoryManager
+ }
+
def runTask(context: TaskContext): T
def preferredLocations: Seq[TaskLocation] = Nil
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8a4f2a08fe..34ac9361d4 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 70529d9216..668ddf9f5f 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index aea76c1adc..85eb2a1d07 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0, 0, null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 057e226916..83ae870124 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 37b593b2c5..2080c432d7 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0),
+ new TaskContextImpl(0, 0, 0, 0, null),
transfer,
blockManager,
blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/pom.xml b/pom.xml
index 928f5d0f5e..c85c5feeaf 100644
--- a/pom.xml
+++ b/pom.xml
@@ -97,6 +97,7 @@
<module>sql/catalyst</module>
<module>sql/core</module>
<module>sql/hive</module>
+ <module>unsafe</module>
<module>assembly</module>
<module>external/twitter</module>
<module>external/flume</module>
@@ -1215,6 +1216,7 @@
<spark.ui.enabled>false</spark.ui.enabled>
<spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
<spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+ <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
</systemProperties>
<failIfNoTests>false</failIfNoTests>
</configuration>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 09b4976d10..b7dbcd9bc5 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -34,11 +34,11 @@ object BuildCommons {
val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
- streamingMqtt, streamingTwitter, streamingZeromq, launcher) =
+ streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
"sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
"streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
- "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _))
+ "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl,
sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
@@ -159,7 +159,7 @@ object SparkBuild extends PomBuild {
// TODO: Add Sql to mima checks
// TODO: remove launcher from this list after 1.3.
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
- networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach {
+ networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach {
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}
@@ -496,6 +496,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.ui.enabled=false",
javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
+ javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 3dea2ee765..5c322d032d 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -51,6 +51,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
new file mode 100644
index 0000000000..299ff3728a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
+ *
+ * This map supports a maximum of 2 billion keys.
+ */
+public final class UnsafeFixedWidthAggregationMap {
+
+ /**
+ * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+ * map, we copy this buffer and use it as the value.
+ */
+ private final long[] emptyAggregationBuffer;
+
+ private final StructType aggregationBufferSchema;
+
+ private final StructType groupingKeySchema;
+
+ /**
+ * Encodes grouping keys as UnsafeRows.
+ */
+ private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+
+ /**
+ * A hashmap which maps from opaque bytearray keys to bytearray values.
+ */
+ private final BytesToBytesMap map;
+
+ /**
+ * Re-used pointer to the current aggregation buffer
+ */
+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+
+ /**
+ * Scratch space that is used when encoding grouping keys into UnsafeRow format.
+ *
+ * By default, this is a 1 KB (128-word) array, but it will grow as necessary in case larger keys
+ * encountered.
+ */
+ private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
+
+ private final boolean enablePerfMetrics;
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
+ * false otherwise.
+ */
+ public static boolean supportsGroupKeySchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+ * schema, false otherwise.
+ */
+ public static boolean supportsAggregationBufferSchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Create a new UnsafeFixedWidthAggregationMap.
+ *
+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+ * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
+ */
+ public UnsafeFixedWidthAggregationMap(
+ Row emptyAggregationBuffer,
+ StructType aggregationBufferSchema,
+ StructType groupingKeySchema,
+ TaskMemoryManager memoryManager,
+ int initialCapacity,
+ boolean enablePerfMetrics) {
+ this.emptyAggregationBuffer =
+ convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
+ this.aggregationBufferSchema = aggregationBufferSchema;
+ this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
+ this.groupingKeySchema = groupingKeySchema;
+ this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.enablePerfMetrics = enablePerfMetrics;
+ }
+
+ /**
+ * Convert a Java object row into an UnsafeRow, allocating it into a new long array.
+ */
+ private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
+ final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
+ final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
+ final long writtenLength =
+ converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
+ assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
+ return unsafeRow;
+ }
+
+ /**
+ * Return the aggregation buffer for the current group. For efficiency, all calls to this method
+ * return the same object.
+ */
+ public UnsafeRow getAggregationBuffer(Row groupingKey) {
+ final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
+ // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
+ if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
+ // This new array will be initially zero, so there's no need to zero it out here
+ groupingKeyConversionScratchSpace = new long[groupingKeySize];
+ } else {
+ // Zero out the buffer that's used to hold the current row. This is necessary in order
+ // to ensure that rows hash properly, since garbage data from the previous row could
+ // otherwise end up as padding in this row. As a performance optimization, we only zero out
+ // the portion of the buffer that we'll actually write to.
+ Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0);
+ }
+ final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
+ groupingKey,
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET);
+ assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
+
+ // Probe our map using the serialized key
+ final BytesToBytesMap.Location loc = map.lookup(
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ groupingKeySize);
+ if (!loc.isDefined()) {
+ // This is the first time that we've seen this grouping key, so we'll insert a copy of the
+ // empty aggregation buffer into the map:
+ loc.putNewKey(
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ groupingKeySize,
+ emptyAggregationBuffer,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ emptyAggregationBuffer.length
+ );
+ }
+
+ // Reset the pointer to point to the value that we just stored or looked up:
+ final MemoryLocation address = loc.getValueAddress();
+ currentAggregationBuffer.pointTo(
+ address.getBaseObject(),
+ address.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ aggregationBufferSchema
+ );
+ return currentAggregationBuffer;
+ }
+
+ /**
+ * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
+ */
+ public static class MapEntry {
+ private MapEntry() { };
+ public final UnsafeRow key = new UnsafeRow();
+ public final UnsafeRow value = new UnsafeRow();
+ }
+
+ /**
+ * Returns an iterator over the keys and values in this map.
+ *
+ * For efficiency, each call returns the same object.
+ */
+ public Iterator<MapEntry> iterator() {
+ return new Iterator<MapEntry>() {
+
+ private final MapEntry entry = new MapEntry();
+ private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+
+ @Override
+ public boolean hasNext() {
+ return mapLocationIterator.hasNext();
+ }
+
+ @Override
+ public MapEntry next() {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ entry.key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ groupingKeySchema
+ );
+ entry.value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ aggregationBufferSchema
+ );
+ return entry;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Free the unsafe memory associated with this map.
+ */
+ public void free() {
+ map.free();
+ }
+
+ @SuppressWarnings("UseOfSystemOutOrSystemErr")
+ public void printPerfMetrics() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException("Perf metrics not enabled");
+ }
+ System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+ System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+ System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+ System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ }
+
+}
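
(Annotation, not part of the patch.) A short usage sketch of the map defined above, mirroring the UnsafeFixedWidthAggregationMapSuite added later in this commit: look up a grouping key, mutate the returned UnsafeRow in place, then free the backing memory.

    import org.apache.spark.sql.catalyst.expressions.{GenericRow, UnsafeFixedWidthAggregationMap}
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
    val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))

    val map = new UnsafeFixedWidthAggregationMap(
      new GenericRow(Array[Any](0)),  // "zero" buffer copied in for unseen keys
      aggBufferSchema,
      groupKeySchema,
      memoryManager,
      1024,   // initial capacity
      false)  // perf metrics disabled
    // getAggregationBuffer returns a mutable UnsafeRow backed by the map's memory,
    // so writes through it update the map in place.
    val buffer = map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String("cats"))))
    buffer.setInt(0, buffer.getInt(0) + 42)
    map.free()
    memoryManager.cleanUpAllAllocatedMemory()
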
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
new file mode 100644
index 0000000000..0a358ed408
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import scala.collection.Map;
+import scala.collection.Seq;
+import scala.collection.mutable.ArraySeq;
+
+import javax.annotation.Nullable;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.util.*;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.DataType;
+import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.UTF8String;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+
+/**
+ * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
+ *
+ * Each tuple has three parts: [null bit set] [values] [variable length portion]
+ *
+ * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores
+ * one bit per field.
+ *
+ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
+ * primitive types, such as long, double, or int, we store the value directly in the word. For
+ * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
+ * base address of the row) that points to the beginning of the variable-length field.
+ *
+ * Instances of `UnsafeRow` act as pointers to row data stored in this format.
+ */
+public final class UnsafeRow implements MutableRow {
+
+ private Object baseObject;
+ private long baseOffset;
+
+ Object getBaseObject() { return baseObject; }
+ long getBaseOffset() { return baseOffset; }
+
+ /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
+ private int numFields;
+
+ /** The width of the null tracking bit set, in bytes */
+ private int bitSetWidthInBytes;
+ /**
+ * This optional schema is required if you want to call generic get() and set() methods on
+ * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE()
+ * methods. This should be removed after the planned InternalRow / Row split; right now, it's only
+ * needed by the generic get() method, which is only called internally by code that accesses
+ * UTF8String-typed columns.
+ */
+ @Nullable
+ private StructType schema;
+
+ private long getFieldOffset(int ordinal) {
+ return baseOffset + bitSetWidthInBytes + ordinal * 8L;
+ }
+
+ public static int calculateBitSetWidthInBytes(int numFields) {
+ return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
+ }
+
+ /**
+ * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
+ */
+ public static final Set<DataType> settableFieldTypes;
+
+ /**
+ * Field types that can be read but not set (e.g. set() will throw UnsupportedOperationException).
+ */
+ public static final Set<DataType> readableFieldTypes;
+
+ static {
+ settableFieldTypes = Collections.unmodifiableSet(
+ new HashSet<DataType>(
+ Arrays.asList(new DataType[] {
+ NullType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType
+ })));
+
+ // We support get() on a superset of the types for which we support set():
+ final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
+ Arrays.asList(new DataType[]{
+ StringType
+ }));
+ _readableFieldTypes.addAll(settableFieldTypes);
+ readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
+ }
+
+ /**
+ * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
+ * since the value returned by this constructor is equivalent to a null pointer.
+ */
+ public UnsafeRow() { }
+
+ /**
+ * Update this UnsafeRow to point to different backing data.
+ *
+ * @param baseObject the base object
+ * @param baseOffset the offset within the base object
+ * @param numFields the number of fields in this row
+ * @param schema an optional schema; this is necessary if you want to call generic get() or set()
+ * methods on this row, but is optional if the caller will only use type-specific
+ * getTYPE() and setTYPE() methods.
+ */
+ public void pointTo(
+ Object baseObject,
+ long baseOffset,
+ int numFields,
+ @Nullable StructType schema) {
+ assert numFields >= 0 : "numFields should >= 0";
+ assert schema == null || schema.fields().length == numFields;
+ this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+ this.baseObject = baseObject;
+ this.baseOffset = baseOffset;
+ this.numFields = numFields;
+ this.schema = schema;
+ }
+
+ private void assertIndexIsValid(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < numFields : "index (" + index + ") should < " + numFields;
+ }
+
+ @Override
+ public void setNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.set(baseObject, baseOffset, i);
+ // To preserve row equality, zero out the value when setting the column to null.
+ // Since this row does does not currently support updates to variable-length values, we don't
+ // have to worry about zeroing out that data.
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
+ }
+
+ private void setNotNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.unset(baseObject, baseOffset, i);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setString(int ordinal, String value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int size() {
+ return numFields;
+ }
+
+ @Override
+ public int length() {
+ return size();
+ }
+
+ @Override
+ public StructType schema() {
+ return schema;
+ }
+
+ @Override
+ public Object apply(int i) {
+ return get(i);
+ }
+
+ @Override
+ public Object get(int i) {
+ assertIndexIsValid(i);
+ assert (schema != null) : "Schema must be defined when calling generic get() method";
+ final DataType dataType = schema.fields()[i].dataType();
+ // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
+ // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
+ // separate the internal and external row interfaces, then internal code can fetch strings via
+ // a new getUTF8String() method and we'll be able to remove this method.
+ if (isNullAt(i)) {
+ return null;
+ } else if (dataType == StringType) {
+ return getUTF8String(i);
+ } else {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int i) {
+ assertIndexIsValid(i);
+ return BitSetMethods.isSet(baseObject, baseOffset, i);
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public byte getByte(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public short getShort(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public int getInt(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public long getLong(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public float getFloat(int i) {
+ assertIndexIsValid(i);
+ if (isNullAt(i)) {
+ return Float.NaN;
+ } else {
+ return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
+ }
+ }
+
+ @Override
+ public double getDouble(int i) {
+ assertIndexIsValid(i);
+ if (isNullAt(i)) {
+ return Double.NaN;
+ } else {
+ return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
+ }
+ }
+
+ public UTF8String getUTF8String(int i) {
+ assertIndexIsValid(i);
+ final UTF8String str = new UTF8String();
+ final long offsetToStringSize = getLong(i);
+ final int stringSizeInBytes =
+ (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
+ final byte[] strBytes = new byte[stringSizeInBytes];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ stringSizeInBytes
+ );
+ str.set(strBytes);
+ return str;
+ }
+
+ @Override
+ public String getString(int i) {
+ return getUTF8String(i).toString();
+ }
+
+ @Override
+ public BigDecimal getDecimal(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Date getDate(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> Seq<T> getSeq(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> List<T> getList(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> Map<K, V> getMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> java.util.Map<K, V> getJavaMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row getStruct(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(String fieldName) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int fieldIndex(String name) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean anyNull() {
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+ }
+
+ @Override
+ public Seq<Object> toSeq() {
+ final ArraySeq<Object> values = new ArraySeq<Object>(numFields);
+ for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) {
+ values.update(fieldNumber, get(fieldNumber));
+ }
+ return values;
+ }
+
+ @Override
+ public String toString() {
+ return mkString("[", ",", "]");
+ }
+
+ @Override
+ public String mkString() {
+ return toSeq().mkString();
+ }
+
+ @Override
+ public String mkString(String sep) {
+ return toSeq().mkString(sep);
+ }
+
+ @Override
+ public String mkString(String start, String sep, String end) {
+ return toSeq().mkString(start, sep, end);
+ }
+}
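
(Annotation, not part of the patch.) A worked example of the row layout described in the class comment: the null bitset is padded to 8-byte words, so any row with up to 64 fields spends exactly one word on it, and field i lives at baseOffset + bitSetWidth + i * 8.

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    val bitSetWidth = UnsafeRow.calculateBitSetWidthInBytes(3)  // 8 bytes for a 3-field row
    val fixedLength = bitSetWidth + 3 * 8                       // 32 bytes before any variable-length data
    assert(UnsafeRow.calculateBitSetWidthInBytes(65) == 16)     // a 65-field row needs two bitset words
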
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
new file mode 100644
index 0000000000..5b2c857278
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+
+/**
+ * Converts Rows into UnsafeRow format. This class is NOT thread-safe.
+ *
+ * @param fieldTypes the data types of the row's columns.
+ */
+class UnsafeRowConverter(fieldTypes: Array[DataType]) {
+
+ def this(schema: StructType) {
+ this(schema.fields.map(_.dataType))
+ }
+
+ /** Re-used pointer to the unsafe row being written */
+ private[this] val unsafeRow = new UnsafeRow()
+
+ /** Functions for encoding each column */
+ private[this] val writers: Array[UnsafeColumnWriter] = {
+ fieldTypes.map(t => UnsafeColumnWriter.forType(t))
+ }
+
+ /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
+ private[this] val fixedLengthSize: Int =
+ (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
+
+ /**
+ * Compute the amount of space, in bytes, required to encode the given row.
+ */
+ def getSizeRequirement(row: Row): Int = {
+ var fieldNumber = 0
+ var variableLengthFieldSize: Int = 0
+ while (fieldNumber < writers.length) {
+ if (!row.isNullAt(fieldNumber)) {
+ variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber)
+ }
+ fieldNumber += 1
+ }
+ fixedLengthSize + variableLengthFieldSize
+ }
+
+ /**
+ * Convert the given row into UnsafeRow format.
+ *
+ * @param row the row to convert
+ * @param baseObject the base object of the destination address
+ * @param baseOffset the base offset of the destination address
+ * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
+ */
+ def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
+ unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+ var fieldNumber = 0
+ var appendCursor: Int = fixedLengthSize
+ while (fieldNumber < writers.length) {
+ if (row.isNullAt(fieldNumber)) {
+ unsafeRow.setNullAt(fieldNumber)
+ } else {
+ appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
+ }
+ fieldNumber += 1
+ }
+ appendCursor
+ }
+
+}
+
+/**
+ * Function for writing a column into an UnsafeRow.
+ */
+private abstract class UnsafeColumnWriter {
+ /**
+ * Write a value into an UnsafeRow.
+ *
+ * @param source the row being converted
+ * @param target a pointer to the converted unsafe row
+ * @param column the column to write
+ * @param appendCursor the offset from the start of the unsafe row to the end of the row;
+ * used for calculating where variable-length data should be written
+ * @return the number of variable-length bytes written
+ */
+ def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int
+
+ /**
+ * Return the number of bytes that are needed to write this variable-length value.
+ */
+ def getSize(source: Row, column: Int): Int
+}
+
+private object UnsafeColumnWriter {
+
+ def forType(dataType: DataType): UnsafeColumnWriter = {
+ dataType match {
+ case NullType => NullUnsafeColumnWriter
+ case BooleanType => BooleanUnsafeColumnWriter
+ case ByteType => ByteUnsafeColumnWriter
+ case ShortType => ShortUnsafeColumnWriter
+ case IntegerType => IntUnsafeColumnWriter
+ case LongType => LongUnsafeColumnWriter
+ case FloatType => FloatUnsafeColumnWriter
+ case DoubleType => DoubleUnsafeColumnWriter
+ case StringType => StringUnsafeColumnWriter
+ case t =>
+ throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
+ }
+ }
+}
+
+// ------------------------------------------------------------------------------------------------
+
+private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter
+private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter
+private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter
+private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter
+private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
+private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
+private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
+private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
+private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+
+private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
+ // Primitives don't write to the variable-length region:
+ def getSize(sourceRow: Row, column: Int): Int = 0
+}
+
+private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setNullAt(column)
+ 0
+ }
+}
+
+private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setBoolean(column, source.getBoolean(column))
+ 0
+ }
+}
+
+private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setByte(column, source.getByte(column))
+ 0
+ }
+}
+
+private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setShort(column, source.getShort(column))
+ 0
+ }
+}
+
+private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setInt(column, source.getInt(column))
+ 0
+ }
+}
+
+private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setLong(column, source.getLong(column))
+ 0
+ }
+}
+
+private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setFloat(column, source.getFloat(column))
+ 0
+ }
+}
+
+private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ target.setDouble(column, source.getDouble(column))
+ 0
+ }
+}
+
+private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
+ def getSize(source: Row, column: Int): Int = {
+ val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
+ 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ }
+
+ override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+ val value = source.get(column).asInstanceOf[UTF8String]
+ val baseObject = target.getBaseObject
+ val baseOffset = target.getBaseOffset
+ val numBytes = value.getBytes.length
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+ PlatformDependent.copyMemory(
+ value.getBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ baseOffset + appendCursor + 8,
+ numBytes
+ )
+ target.setLong(column, appendCursor)
+ 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ }
+}
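
(Annotation, not part of the patch.) The converter's size bookkeeping in one small sketch: the fixed-length region is one bitset word plus an 8-byte slot per field (hence the 8 + (3 * 8) expected by the converter test below), and each string column appends an 8-byte length word followed by its bytes rounded up to a word boundary.

    import org.apache.spark.unsafe.array.ByteArrayMethods

    // Variable-length footprint of a single string column, per StringUnsafeColumnWriter above.
    def stringColumnSize(numBytes: Int): Int =
      8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)

    assert(stringColumnSize(5) == 16)  // 8-byte length word + 5 bytes padded to 8
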
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
new file mode 100644
index 0000000000..7a19e511eb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+
+import org.apache.spark.sql.types._
+
+class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+
+ import UnsafeFixedWidthAggregationMap._
+
+ private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
+ private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+ private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
+
+ private var memoryManager: TaskMemoryManager = null
+
+ override def beforeEach(): Unit = {
+ memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ }
+
+ override def afterEach(): Unit = {
+ if (memoryManager != null) {
+ memoryManager.cleanUpAllAllocatedMemory()
+ memoryManager = null
+ }
+ }
+
+ test("supported schemas") {
+ assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
+ assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
+
+ assert(
+ !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ assert(
+ !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ }
+
+ test("empty map") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ memoryManager,
+ 1024, // initial capacity
+ false // disable perf metrics
+ )
+ assert(!map.iterator().hasNext)
+ map.free()
+ }
+
+ test("updating values for a single key") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ memoryManager,
+ 1024, // initial capacity
+ false // disable perf metrics
+ )
+ val groupKey = new GenericRow(Array[Any](UTF8String("cats")))
+
+ // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
+ map.getAggregationBuffer(groupKey)
+ val iter = map.iterator()
+ val entry = iter.next()
+ assert(!iter.hasNext)
+ entry.key.getString(0) should be ("cats")
+ entry.value.getInt(0) should be (0)
+
+ // Modifications to rows retrieved from the map should update the values in the map
+ entry.value.setInt(0, 42)
+ map.getAggregationBuffer(groupKey).getInt(0) should be (42)
+
+ map.free()
+ }
+
+ test("inserting large random keys") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ memoryManager,
+ 128, // initial capacity
+ false // disable perf metrics
+ )
+ val rand = new Random(42)
+ val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
+ groupKeys.foreach { keyString =>
+ map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString))))
+ }
+ val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
+ entry.key.getString(0)
+ }.toSet
+ seenKeys.size should be (groupKeys.size)
+ seenKeys should be (groupKeys)
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
new file mode 100644
index 0000000000..3a60c7fd32
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Arrays
+
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+
+class UnsafeRowConverterSuite extends FunSuite with Matchers {
+
+ test("basic conversion with only primitive types") {
+ val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.setLong(1, 1)
+ row.setInt(2, 2)
+
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ sizeRequired should be (8 + (3 * 8))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.getLong(0) should be (0)
+ unsafeRow.getLong(1) should be (1)
+ unsafeRow.getInt(2) should be (2)
+ }
+
+ test("basic conversion with primitive and string types") {
+ val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.setString(1, "Hello")
+ row.setString(2, "World")
+
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ sizeRequired should be (8 + (8 * 3) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.getLong(0) should be (0)
+ unsafeRow.getString(1) should be ("Hello")
+ unsafeRow.getString(2) should be ("World")
+ }
+
+ test("null handling") {
+ val fieldTypes: Array[DataType] = Array(
+ NullType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val rowWithAllNullColumns: Row = {
+ val r = new SpecificMutableRow(fieldTypes)
+ for (i <- 0 to fieldTypes.length - 1) {
+ r.setNullAt(i)
+ }
+ r
+ }
+
+ val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
+ val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(
+ rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val createdFromNull = new UnsafeRow()
+ createdFromNull.pointTo(
+ createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ for (i <- 0 to fieldTypes.length - 1) {
+ assert(createdFromNull.isNullAt(i))
+ }
+ createdFromNull.getBoolean(1) should be (false)
+ createdFromNull.getByte(2) should be (0)
+ createdFromNull.getShort(3) should be (0)
+ createdFromNull.getInt(4) should be (0)
+ createdFromNull.getLong(5) should be (0)
+ assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
+ assert(java.lang.Double.isNaN(createdFromNull.getDouble(7)))
+
+ // If we have an UnsafeRow with columns that are initially non-null and we null out those
+ // columns, then the serialized row representation should be identical to what we would get by
+ // creating an entirely null row via the converter
+ val rowWithNoNullColumns: Row = {
+ val r = new SpecificMutableRow(fieldTypes)
+ r.setNullAt(0)
+ r.setBoolean(1, false)
+ r.setByte(2, 20)
+ r.setShort(3, 30)
+ r.setInt(4, 400)
+ r.setLong(5, 500)
+ r.setFloat(6, 600)
+ r.setDouble(7, 700)
+ r
+ }
+ val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ converter.writeRow(
+ rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ val setToNullAfterCreation = new UnsafeRow()
+ setToNullAfterCreation.pointTo(
+ setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+
+ setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0))
+ setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1))
+ setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2))
+ setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3))
+ setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4))
+ setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5))
+ setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6))
+ setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7))
+
+ for (i <- 0 to fieldTypes.length - 1) {
+ setToNullAfterCreation.setNullAt(i)
+ }
+ assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4fc5de7e82..2fa602a608 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,6 +30,7 @@ private[spark] object SQLConf {
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val CODEGEN_ENABLED = "spark.sql.codegen"
+ val UNSAFE_ENABLED = "spark.sql.unsafe.enabled"
val DIALECT = "spark.sql.dialect"
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
@@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable {
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
+ /**
+ * When set to true, Spark SQL will use managed memory for certain operations. This option only
+ * takes effect if codegen is enabled.
+ *
+ * Defaults to false as this feature is currently experimental.
+ */
+ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean
+
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
/**
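Usage note: per the javadoc above, `spark.sql.unsafe.enabled` only takes effect when codegen is also on, and both default to "false". A minimal, hypothetical Java snippet that enables both for a session (the `setConf` calls are standard `SQLContext` API; master and app name are placeholders):

    import org.apache.spark.SparkConf;
    import org.apache.spark.SparkContext;
    import org.apache.spark.sql.SQLContext;

    public class EnableUnsafeSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("unsafe-demo");
        SparkContext sc = new SparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        sqlContext.setConf("spark.sql.codegen", "true");        // required for the unsafe path
        sqlContext.setConf("spark.sql.unsafe.enabled", "true"); // the new flag added in this patch
        // ... run aggregation queries; GeneratedAggregate may now pick the unsafe aggregation map ...
        sc.stop();
      }
    }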
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a279b0f07c..bd4a55fa13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
def codegenEnabled: Boolean = self.conf.codegenEnabled
+ def unsafeEnabled: Boolean = self.conf.unsafeEnabled
+
def numPartitions: Int = self.conf.numShufflePartitions
def strategies: Seq[Strategy] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b1ef6556de..5d9f202681 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.trees._
@@ -40,6 +41,7 @@ case class AggregateEvaluation(
* ensure all values where `groupingExpressions` are equal are present.
* @param groupingExpressions expressions that are evaluated to determine grouping.
* @param aggregateExpressions expressions that are computed for each group.
+ * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
* @param child the input data source.
*/
@DeveloperApi
@@ -47,6 +49,7 @@ case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
+ unsafeEnabled: Boolean,
child: SparkPlan)
extends UnaryNode {
@@ -225,6 +228,21 @@ case class GeneratedAggregate(
case e: Expression if groupMap.contains(e) => groupMap(e)
})
+ val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
+
+ val groupKeySchema: StructType = {
+ val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+ // This is a dummy field name
+ StructField(idx.toString, expr.dataType, expr.nullable)
+ }
+ StructType(fields)
+ }
+
+ val schemaSupportsUnsafe: Boolean = {
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+ }
+
child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
@@ -265,7 +283,49 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
+ } else if (unsafeEnabled && schemaSupportsUnsafe) {
+ log.info("Using Unsafe-based aggregator")
+ val aggregationMap = new UnsafeFixedWidthAggregationMap(
+ newAggregationBuffer(EmptyRow),
+ aggregationBufferSchema,
+ groupKeySchema,
+ TaskContext.get.taskMemoryManager(),
+ 1024 * 16, // initial capacity
+ false // disable tracking of performance metrics
+ )
+
+ while (iter.hasNext) {
+ val currentRow: Row = iter.next()
+ val groupKey: Row = groupProjection(currentRow)
+ val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+ updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
+ }
+
+ new Iterator[Row] {
+ private[this] val mapIterator = aggregationMap.iterator()
+ private[this] val resultProjection = resultProjectionBuilder()
+
+ def hasNext: Boolean = mapIterator.hasNext
+
+ def next(): Row = {
+ val entry = mapIterator.next()
+ val result = resultProjection(joinedRow(entry.key, entry.value))
+ if (hasNext) {
+ result
+ } else {
+ // This is the last element in the iterator, so let's free the buffer. Before we do,
+ // though, we need to make a defensive copy of the result so that we don't return an
+ // object that might contain dangling pointers to the freed memory
+ val resultCopy = result.copy()
+ aggregationMap.free()
+ resultCopy
+ }
+ }
+ }
} else {
+ if (unsafeEnabled) {
+ log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
+ }
val buffers = new java.util.HashMap[Row, MutableRow]()
var currentRow: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3a0a6c8670..af58911cc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -136,10 +136,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = false,
namedGroupingAttributes,
rewrittenAggregateExpressions,
+ unsafeEnabled,
execution.GeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
+ unsafeEnabled,
planLater(child))) :: Nil
// Cases where some aggregate can not be codegened
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
new file mode 100644
index 0000000000..8901d77591
--- /dev/null
+++ b/unsafe/pom.xml
@@ -0,0 +1,69 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent_2.10</artifactId>
+ <version>1.4.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-unsafe_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Unsafe</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>unsafe</sbt.project.name>
+ </properties>
+
+ <dependencies>
+
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
new file mode 100644
index 0000000000..91b2f9aa43
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe;
+
+import java.lang.reflect.Field;
+
+import sun.misc.Unsafe;
+
+public final class PlatformDependent {
+
+ public static final Unsafe UNSAFE;
+
+ public static final int BYTE_ARRAY_OFFSET;
+
+ public static final int INT_ARRAY_OFFSET;
+
+ public static final int LONG_ARRAY_OFFSET;
+
+ public static final int DOUBLE_ARRAY_OFFSET;
+
+ /**
+ * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
+ * allow safepoint polling during a large copy.
+ */
+ private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
+
+ static {
+ sun.misc.Unsafe unsafe;
+ try {
+ Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafeField.setAccessible(true);
+ unsafe = (sun.misc.Unsafe) unsafeField.get(null);
+ } catch (Throwable cause) {
+ unsafe = null;
+ }
+ UNSAFE = unsafe;
+
+ if (UNSAFE != null) {
+ BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
+ INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class);
+ LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class);
+ DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class);
+ } else {
+ BYTE_ARRAY_OFFSET = 0;
+ INT_ARRAY_OFFSET = 0;
+ LONG_ARRAY_OFFSET = 0;
+ DOUBLE_ARRAY_OFFSET = 0;
+ }
+ }
+
+ public static void copyMemory(
+ Object src,
+ long srcOffset,
+ Object dst,
+ long dstOffset,
+ long length) {
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ srcOffset += size;
+ dstOffset += size;
+ }
+ }
+
+ /**
+ * Raises an exception bypassing compiler checks for checked exceptions.
+ */
+ public static void throwException(Throwable t) {
+ UNSAFE.throwException(t);
+ }
+}
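Usage illustration (not part of the patch): `copyMemory` copies between arbitrary base object / offset pairs, chunking large copies at UNSAFE_COPY_THRESHOLD so the JVM can reach a safepoint during the copy. A small sketch copying a `byte[]` into a `long[]` buffer; the class name is hypothetical:

    import org.apache.spark.unsafe.PlatformDependent;

    public class CopyMemorySketch {
      public static void main(String[] args) {
        byte[] src = "hello unsafe".getBytes();
        long[] dst = new long[8];  // 64 bytes of destination storage
        // Copy the bytes into the long[] starting at its first element.
        PlatformDependent.copyMemory(
          src, PlatformDependent.BYTE_ARRAY_OFFSET,
          dst, PlatformDependent.LONG_ARRAY_OFFSET,
          src.length);
        // Read the first 8 copied bytes back as a long (byte order is platform dependent).
        System.out.println(PlatformDependent.UNSAFE.getLong(dst, (long) PlatformDependent.LONG_ARRAY_OFFSET));
      }
    }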
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
new file mode 100644
index 0000000000..53eadf96a6
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class ByteArrayMethods {
+
+ private ByteArrayMethods() {
+ // Private constructor, since this class only contains static methods.
+ }
+
+ public static int roundNumberOfBytesToNearestWord(int numBytes) {
+ int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
+ if (remainder == 0) {
+ return numBytes;
+ } else {
+ return numBytes + (8 - remainder);
+ }
+ }
+
+ /**
+ * Optimized byte array equality check for 8-byte-word-aligned byte arrays.
+ * @return true if the arrays are equal, false otherwise
+ */
+ public static boolean wordAlignedArrayEquals(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset,
+ long arrayLengthInBytes) {
+ for (int i = 0; i < arrayLengthInBytes; i += 8) {
+ final long left =
+ PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i);
+ final long right =
+ PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i);
+ if (left != right) return false;
+ }
+ return true;
+ }
+}
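A brief usage sketch (not part of the patch) for the two helpers above: rounding a byte count up to the next word boundary, and comparing two word-aligned regions. The class name is hypothetical; the length passed to `wordAlignedArrayEquals` must be a multiple of 8.

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.unsafe.array.ByteArrayMethods;

    public class ByteArrayMethodsSketch {
      public static void main(String[] args) {
        // 13 bytes rounds up to the next multiple of 8.
        System.out.println(ByteArrayMethods.roundNumberOfBytesToNearestWord(13)); // 16
        long[] left = {1L, 2L};
        long[] right = {1L, 2L};
        boolean equal = ByteArrayMethods.wordAlignedArrayEquals(
          left, PlatformDependent.LONG_ARRAY_OFFSET,
          right, PlatformDependent.LONG_ARRAY_OFFSET,
          16);  // length in bytes
        System.out.println(equal); // true
      }
    }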
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
new file mode 100644
index 0000000000..18d1f0d2d7
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * An array of long values. Compared with native JVM arrays, this:
+ * <ul>
+ * <li>supports using both in-heap and off-heap memory</li>
+ * <li>has no bounds checking, and thus can crash the JVM process when assertions are turned off</li>
+ * </ul>
+ */
+public final class LongArray {
+
+ // This is a long so that we perform long multiplications when computing offsets.
+ private static final long WIDTH = 8;
+
+ private final MemoryBlock memory;
+ private final Object baseObj;
+ private final long baseOffset;
+
+ private final long length;
+
+ public LongArray(MemoryBlock memory) {
+ assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")";
+ assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements";
+ this.memory = memory;
+ this.baseObj = memory.getBaseObject();
+ this.baseOffset = memory.getBaseOffset();
+ this.length = memory.size() / WIDTH;
+ }
+
+ public MemoryBlock memoryBlock() {
+ return memory;
+ }
+
+ /**
+ * Returns the number of elements this array can hold.
+ */
+ public long size() {
+ return length;
+ }
+
+ /**
+ * Sets the value at position {@code index}.
+ */
+ public void set(int index, long value) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < length : "index (" + index + ") should < length (" + length + ")";
+ PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value);
+ }
+
+ /**
+ * Returns the value at position {@code index}.
+ */
+ public long get(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < length : "index (" + index + ") should < length (" + length + ")";
+ return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH);
+ }
+}
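Usage sketch (not part of the patch): a LongArray wraps a MemoryBlock, which can be backed by an on-heap `long[]` or by off-heap memory. A minimal on-heap example; the class name is hypothetical:

    import org.apache.spark.unsafe.array.LongArray;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class LongArraySketch {
      public static void main(String[] args) {
        // Wrap an on-heap long[] in a MemoryBlock; off-heap blocks are used the same way.
        LongArray arr = new LongArray(MemoryBlock.fromLongArray(new long[16]));
        arr.set(3, 42L);
        System.out.println(arr.size()); // 16
        System.out.println(arr.get(3)); // 42
      }
    }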
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
new file mode 100644
index 0000000000..f72e07fce9
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * A fixed size uncompressed bit set backed by a {@link LongArray}.
+ *
+ * Each bit occupies exactly one bit of storage.
+ */
+public final class BitSet {
+
+ /** A long array for the bits. */
+ private final LongArray words;
+
+ /** Length of the long array. */
+ private final int numWords;
+
+ private final Object baseObject;
+ private final long baseOffset;
+
+ /**
+ * Creates a new {@link BitSet} using the specified memory block. The size of the memory block
+ * must be a multiple of 8 bytes (i.e. 64 bits).
+ */
+ public BitSet(MemoryBlock memory) {
+ words = new LongArray(memory);
+ assert (words.size() <= Integer.MAX_VALUE);
+ numWords = (int) words.size();
+ baseObject = words.memoryBlock().getBaseObject();
+ baseOffset = words.memoryBlock().getBaseOffset();
+ }
+
+ public MemoryBlock memoryBlock() {
+ return words.memoryBlock();
+ }
+
+ /**
+ * Returns the number of bits in this {@code BitSet}.
+ */
+ public long capacity() {
+ return numWords * 64;
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code true}.
+ */
+ public void set(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ BitSetMethods.set(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code false}.
+ */
+ public void unset(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ BitSetMethods.unset(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Returns {@code true} if the bit is set at the specified index.
+ */
+ public boolean isSet(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ return BitSetMethods.isSet(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then {@code -1} is returned.
+ * <p>
+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ * for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
+ * // operate on index i here
+ * }
+ * </code>
+ * </pre>
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ public int nextSetBit(int fromIndex) {
+ return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords);
+ }
+}
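Usage sketch (not part of the patch): a BitSet over a 2-word (128-bit) heap block, iterated with the `nextSetBit` loop described in the javadoc above. The class name is hypothetical; the backing `long[]` is zero-initialized by the JVM, so the bit set starts cleared.

    import org.apache.spark.unsafe.bitset.BitSet;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class BitSetSketch {
      public static void main(String[] args) {
        // The backing block must be a multiple of 8 bytes; 2 words give 128 bits.
        BitSet bits = new BitSet(MemoryBlock.fromLongArray(new long[2]));
        bits.set(3);
        bits.set(70);
        // Iterate over the set bits, exactly as in the javadoc example.
        for (int i = bits.nextSetBit(0); i >= 0; i = bits.nextSetBit(i + 1)) {
          System.out.println(i); // prints 3, then 70
        }
      }
    }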
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
new file mode 100644
index 0000000000..f30626d8f4
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Methods for working with fixed-size uncompressed bitsets.
+ *
+ * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length).
+ *
+ * Each bit occupies exactly one bit of storage.
+ */
+public final class BitSetMethods {
+
+ private static final long WORD_SIZE = 8;
+
+ private BitSetMethods() {
+ // Make the default constructor private, since this only holds static methods.
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code true}.
+ */
+ public static void set(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask);
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code false}.
+ */
+ public static void unset(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask);
+ }
+
+ /**
+ * Returns {@code true} if the bit is set at the specified index.
+ */
+ public static boolean isSet(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ return (word & mask) != 0;
+ }
+
+ /**
+ * Returns {@code true} if any bit is set.
+ */
+ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) {
+ for (int i = 0; i < bitSetWidthInBytes; i++) {
+ if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then {@code -1} is returned.
+ * <p>
+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ * for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+ * // operate on index i here
+ * }
+ * </code>
+ * </pre>
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ public static int nextSetBit(
+ Object baseObject,
+ long baseOffset,
+ int fromIndex,
+ int bitsetSizeInWords) {
+ int wi = fromIndex >> 6;
+ if (wi >= bitsetSizeInWords) {
+ return -1;
+ }
+
+ // Try to find the next set bit in the current word
+ final int subIndex = fromIndex & 0x3f;
+ long word =
+ PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex;
+ if (word != 0) {
+ return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word);
+ }
+
+ // Find the next set bit in the rest of the words
+ wi += 1;
+ while (wi < bitsetSizeInWords) {
+ word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE);
+ if (word != 0) {
+ return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word);
+ }
+ wi += 1;
+ }
+
+ return -1;
+ }
+}
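For completeness, a sketch (not part of the patch) of the static, offset-based variant used by BytesToBytesMap, operating directly on a raw word array. The class name is hypothetical; `anySet` takes the bitset width in bytes and `nextSetBit` takes its size in 8-byte words.

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.unsafe.bitset.BitSetMethods;

    public class BitSetMethodsSketch {
      public static void main(String[] args) {
        long[] words = new long[2];  // 128 bits, zero-initialized by the JVM
        long offset = PlatformDependent.LONG_ARRAY_OFFSET;
        BitSetMethods.set(words, offset, 5);
        BitSetMethods.set(words, offset, 90);
        System.out.println(BitSetMethods.isSet(words, offset, 5));         // true
        System.out.println(BitSetMethods.anySet(words, offset, 16));       // true
        System.out.println(BitSetMethods.nextSetBit(words, offset, 6, 2)); // 90
        BitSetMethods.unset(words, offset, 90);
        System.out.println(BitSetMethods.nextSetBit(words, offset, 6, 2)); // -1
      }
    }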
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
new file mode 100644
index 0000000000..85cd02469a
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.hash;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
+ */
+public final class Murmur3_x86_32 {
+ private static final int C1 = 0xcc9e2d51;
+ private static final int C2 = 0x1b873593;
+
+ private final int seed;
+
+ public Murmur3_x86_32(int seed) {
+ this.seed = seed;
+ }
+
+ @Override
+ public String toString() {
+ return "Murmur3_32(seed=" + seed + ")";
+ }
+
+ public int hashInt(int input) {
+ int k1 = mixK1(input);
+ int h1 = mixH1(seed, k1);
+
+ return fmix(h1, 4);
+ }
+
+ public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+ // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+ int h1 = seed;
+ for (int offset = 0; offset < lengthInBytes; offset += 4) {
+ int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+ int k1 = mixK1(halfWord);
+ h1 = mixH1(h1, k1);
+ }
+ return fmix(h1, lengthInBytes);
+ }
+
+ public int hashLong(long input) {
+ int low = (int) input;
+ int high = (int) (input >>> 32);
+
+ int k1 = mixK1(low);
+ int h1 = mixH1(seed, k1);
+
+ k1 = mixK1(high);
+ h1 = mixH1(h1, k1);
+
+ return fmix(h1, 8);
+ }
+
+ private static int mixK1(int k1) {
+ k1 *= C1;
+ k1 = Integer.rotateLeft(k1, 15);
+ k1 *= C2;
+ return k1;
+ }
+
+ private static int mixH1(int h1, int k1) {
+ h1 ^= k1;
+ h1 = Integer.rotateLeft(h1, 13);
+ h1 = h1 * 5 + 0xe6546b64;
+ return h1;
+ }
+
+ // Finalization mix - force all bits of a hash block to avalanche
+ private static int fmix(int h1, int length) {
+ h1 ^= length;
+ h1 ^= h1 >>> 16;
+ h1 *= 0x85ebca6b;
+ h1 ^= h1 >>> 13;
+ h1 *= 0xc2b2ae35;
+ h1 ^= h1 >>> 16;
+ return h1;
+ }
+}
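Usage sketch (not part of the patch): the hasher is seeded per instance and exposes int, long, and word-aligned-region hashing; `hashUnsafeWords` is what BytesToBytesMap uses for keys. The class name and seed are arbitrary:

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.unsafe.hash.Murmur3_x86_32;

    public class MurmurSketch {
      public static void main(String[] args) {
        Murmur3_x86_32 hasher = new Murmur3_x86_32(42);  // arbitrary seed
        System.out.println(hasher.hashInt(7));
        System.out.println(hasher.hashLong(7L));
        // hashUnsafeWords requires a word-aligned length (a multiple of 8 bytes).
        long[] words = {1L, 2L};
        System.out.println(
          hasher.hashUnsafeWords(words, PlatformDependent.LONG_ARRAY_OFFSET, 16));
      }
    }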
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
new file mode 100644
index 0000000000..85b64c0833
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -0,0 +1,549 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Override;
+import java.lang.UnsupportedOperationException;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.spark.unsafe.*;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.*;
+
+/**
+ * An append-only hash map where keys and values are contiguous regions of bytes.
+ * <p>
+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
+ * which is guaranteed to exhaust the space.
+ * <p>
+ * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is
+ * higher than this, you should probably be using sorting instead of hashing for better cache
+ * locality.
+ * <p>
+ * This class is not thread safe.
+ */
+public final class BytesToBytesMap {
+
+ private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
+
+ private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
+
+ private final TaskMemoryManager memoryManager;
+
+ /**
+ * A linked list for tracking all allocated data pages so that we can free all of our memory.
+ */
+ private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+
+ /**
+ * The data page that will be used to store keys and values for new hashtable entries. When this
+ * page becomes full, a new page will be allocated and this pointer will change to point to that
+ * new page.
+ */
+ private MemoryBlock currentDataPage = null;
+
+ /**
+ * Offset into `currentDataPage` that points to the location where new data can be inserted into
+ * the page.
+ */
+ private long pageCursor = 0;
+
+ /**
+ * The size of the data pages that hold key and value data. Map entries cannot span multiple
+ * pages, so this limits the maximum entry size.
+ */
+ private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+
+ // This choice of page table size and page size means that we can address up to 500 gigabytes
+ // of memory.
+
+ /**
+ * A single array that stores, for each entry, a pointer to the key and the key's hashcode.
+ *
+ * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
+ * while position {@code 2 * i + 1} in the array holds the key's full 32-bit hashcode.
+ */
+ private LongArray longArray;
+ // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
+ // and exploit word-alignment to use fewer bits to hold the address. This might let us store
+ // only one long per map entry, increasing the chance that this array will fit in cache at the
+ // expense of maybe performing more lookups if we have hash collisions. Say that we stored only
+ // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte
+ // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a
+ // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store
+ // full base addresses in the page table for off-heap mode so that we can reconstruct the full
+ // absolute memory addresses.
+
+ /**
+ * A {@link BitSet} used to track location of the map where the key is set.
+ * Size of the bitset should be half of the size of the long array.
+ */
+ private BitSet bitset;
+
+ private final double loadFactor;
+
+ /**
+ * Number of keys defined in the map.
+ */
+ private int size;
+
+ /**
+ * The map will be expanded once the number of keys exceeds this threshold.
+ */
+ private int growthThreshold;
+
+ /**
+ * Mask for truncating hashcodes so that they do not exceed the long array's size.
+ * This is a strength reduction optimization; we're essentially performing a modulus operation,
+ * but doing so with a bitmask because this is a power-of-2-sized hash map.
+ */
+ private int mask;
+
+ /**
+ * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}.
+ */
+ private final Location loc;
+
+ private final boolean enablePerfMetrics;
+
+ private long timeSpentResizingNs = 0;
+
+ private long numProbes = 0;
+
+ private long numKeyLookups = 0;
+
+ private long numHashCollisions = 0;
+
+ public BytesToBytesMap(
+ TaskMemoryManager memoryManager,
+ int initialCapacity,
+ double loadFactor,
+ boolean enablePerfMetrics) {
+ this.memoryManager = memoryManager;
+ this.loadFactor = loadFactor;
+ this.loc = new Location();
+ this.enablePerfMetrics = enablePerfMetrics;
+ allocate(initialCapacity);
+ }
+
+ public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) {
+ this(memoryManager, initialCapacity, 0.70, false);
+ }
+
+ public BytesToBytesMap(
+ TaskMemoryManager memoryManager,
+ int initialCapacity,
+ boolean enablePerfMetrics) {
+ this(memoryManager, initialCapacity, 0.70, enablePerfMetrics);
+ }
+
+ /**
+ * Returns the number of keys defined in the map.
+ */
+ public int size() { return size; }
+
+ /**
+ * Returns an iterator for iterating over the entries of this map.
+ *
+ * For efficiency, all calls to `next()` will return the same {@link Location} object.
+ *
+ * If any other lookups or operations are performed on this map while iterating over it, including
+ * `lookup()`, the behavior of the returned iterator is undefined.
+ */
+ public Iterator<Location> iterator() {
+ return new Iterator<Location>() {
+
+ private int nextPos = bitset.nextSetBit(0);
+
+ @Override
+ public boolean hasNext() {
+ return nextPos != -1;
+ }
+
+ @Override
+ public Location next() {
+ final int pos = nextPos;
+ nextPos = bitset.nextSetBit(nextPos + 1);
+ return loc.with(pos, 0, true);
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Looks up a key, and returns a {@link Location} handle that can be used to test existence
+ * and read/write values.
+ *
+ * This function always returns the same {@link Location} instance to avoid object allocation.
+ */
+ public Location lookup(
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyRowLengthBytes) {
+ if (enablePerfMetrics) {
+ numKeyLookups++;
+ }
+ final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+ int pos = hashcode & mask;
+ int step = 1;
+ while (true) {
+ if (enablePerfMetrics) {
+ numProbes++;
+ }
+ if (!bitset.isSet(pos)) {
+ // This is a new key.
+ return loc.with(pos, hashcode, false);
+ } else {
+ long stored = longArray.get(pos * 2 + 1);
+ if ((int) (stored) == hashcode) {
+ // Full hash code matches. Let's compare the keys for equality.
+ loc.with(pos, hashcode, true);
+ if (loc.getKeyLength() == keyRowLengthBytes) {
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final Object storedKeyBaseObject = keyAddress.getBaseObject();
+ final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+ final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals(
+ keyBaseObject,
+ keyBaseOffset,
+ storedKeyBaseObject,
+ storedKeyBaseOffset,
+ keyRowLengthBytes
+ );
+ if (areEqual) {
+ return loc;
+ } else {
+ if (enablePerfMetrics) {
+ numHashCollisions++;
+ }
+ }
+ }
+ }
+ }
+ pos = (pos + step) & mask;
+ step++;
+ }
+ }
+
+ /**
+ * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function.
+ */
+ public final class Location {
+ /** An index into the hash map's Long array */
+ private int pos;
+ /** True if this location points to a position where a key is defined, false otherwise */
+ private boolean isDefined;
+ /**
+ * The hashcode of the most recent key passed to
+ * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
+ * avoid re-hashing the key when storing a value for that key.
+ */
+ private int keyHashcode;
+ private final MemoryLocation keyMemoryLocation = new MemoryLocation();
+ private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+ private int keyLength;
+ private int valueLength;
+
+ private void updateAddressesAndSizes(long fullKeyAddress) {
+ final Object page = memoryManager.getPage(fullKeyAddress);
+ final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress);
+ long position = keyOffsetInPage;
+ keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+ position += 8; // word used to store the key size
+ keyMemoryLocation.setObjAndOffset(page, position);
+ position += keyLength;
+ valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+ position += 8; // word used to store the value size
+ valueMemoryLocation.setObjAndOffset(page, position);
+ }
+
+ Location with(int pos, int keyHashcode, boolean isDefined) {
+ this.pos = pos;
+ this.isDefined = isDefined;
+ this.keyHashcode = keyHashcode;
+ if (isDefined) {
+ final long fullKeyAddress = longArray.get(pos * 2);
+ updateAddressesAndSizes(fullKeyAddress);
+ }
+ return this;
+ }
+
+ /**
+ * Returns true if the key is defined at this position, and false otherwise.
+ */
+ public boolean isDefined() {
+ return isDefined;
+ }
+
+ /**
+ * Returns the address of the key defined at this position.
+ * This points to the first byte of the key data.
+ * Unspecified behavior if the key is not defined.
+ * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+ */
+ public MemoryLocation getKeyAddress() {
+ assert (isDefined);
+ return keyMemoryLocation;
+ }
+
+ /**
+ * Returns the length of the key defined at this position.
+ * Unspecified behavior if the key is not defined.
+ */
+ public int getKeyLength() {
+ assert (isDefined);
+ return keyLength;
+ }
+
+ /**
+ * Returns the address of the value defined at this position.
+ * This points to the first byte of the value data.
+ * Unspecified behavior if the key is not defined.
+ * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+ */
+ public MemoryLocation getValueAddress() {
+ assert (isDefined);
+ return valueMemoryLocation;
+ }
+
+ /**
+ * Returns the length of the value defined at this position.
+ * Unspecified behavior if the key is not defined.
+ */
+ public int getValueLength() {
+ assert (isDefined);
+ return valueLength;
+ }
+
+ /**
+ * Store a new key and value. This method may only be called once for a given key; if you want
+ * to update the value associated with a key, then you can directly manipulate the bytes stored
+ * at the value address.
+ * <p>
+ * It is only valid to call this method immediately after calling `lookup()` using the same key.
+ * <p>
+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
+ * will return information on the data stored by this `putNewKey` call.
+ * <p>
+ * As an example usage, here's the proper way to store a new key:
+ * <p>
+ * <pre>
+ *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+ *   if (!loc.isDefined()) {
+ *     loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+ * }
+ * </pre>
+ * <p>
+ * Unspecified behavior if the key is not defined.
+ */
+ public void putNewKey(
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyLengthBytes,
+ Object valueBaseObject,
+ long valueBaseOffset,
+ int valueLengthBytes) {
+ assert (!isDefined) : "Can only set value once for a key";
+ isDefined = true;
+ assert (keyLengthBytes % 8 == 0);
+ assert (valueLengthBytes % 8 == 0);
+ // Here, we'll copy the data into our data pages. Because we only store a relative offset from
+ // the key address instead of storing the absolute address of the value, the key and value
+ // must be stored in the same memory page.
+ // (8 byte key length) (key) (8 byte value length) (value)
+ final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
+ assert(requiredSize <= PAGE_SIZE_BYTES);
+ size++;
+ bitset.set(pos);
+
+ // If there's not enough space in the current page, allocate a new page:
+ if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) {
+ MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
+ dataPages.add(newPage);
+ pageCursor = 0;
+ currentDataPage = newPage;
+ }
+
+ // Compute all of our offsets up-front:
+ final Object pageBaseObject = currentDataPage.getBaseObject();
+ final long pageBaseOffset = currentDataPage.getBaseOffset();
+ final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += 8; // word used to store the key size
+ final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += keyLengthBytes;
+ final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += 8; // word used to store the value size
+ final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += valueLengthBytes;
+
+ // Copy the key
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+ PlatformDependent.UNSAFE.copyMemory(
+ keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+ // Copy the value
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+ PlatformDependent.UNSAFE.copyMemory(
+ valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);
+
+ final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
+ currentDataPage, keySizeOffsetInPage);
+ longArray.set(pos * 2, storedKeyAddress);
+ longArray.set(pos * 2 + 1, keyHashcode);
+ updateAddressesAndSizes(storedKeyAddress);
+ isDefined = true;
+ if (size > growthThreshold) {
+ growAndRehash();
+ }
+ }
+ }
+
+ /**
+ * Allocate new data structures for this map. When calling this outside of the constructor,
+ * make sure to keep references to the old data structures so that you can free them.
+ *
+ * @param capacity the new map capacity
+ */
+ private void allocate(int capacity) {
+ capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64);
+ longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2));
+ bitset = new BitSet(memoryManager.allocate(capacity / 8).zero());
+
+ this.growthThreshold = (int) (capacity * loadFactor);
+ this.mask = capacity - 1;
+ }
+
+ /**
+ * Free all allocated memory associated with this map, including the storage for keys and values
+ * as well as the hash map array itself.
+ *
+ * This method is idempotent.
+ */
+ public void free() {
+ if (longArray != null) {
+ memoryManager.free(longArray.memoryBlock());
+ longArray = null;
+ }
+ if (bitset != null) {
+ memoryManager.free(bitset.memoryBlock());
+ bitset = null;
+ }
+ Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
+ while (dataPagesIterator.hasNext()) {
+ memoryManager.freePage(dataPagesIterator.next());
+ dataPagesIterator.remove();
+ }
+ assert(dataPages.isEmpty());
+ }
+
+ /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+ public long getTotalMemoryConsumption() {
+ return (
+ dataPages.size() * PAGE_SIZE_BYTES +
+ bitset.memoryBlock().size() +
+ longArray.memoryBlock().size());
+ }
+
+ /**
+ * Returns the total amount of time spent resizing this map (in nanoseconds).
+ */
+ public long getTimeSpentResizingNs() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return timeSpentResizingNs;
+ }
+
+
+ /**
+ * Returns the average number of probes per key lookup.
+ */
+ public double getAverageProbesPerLookup() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return (1.0 * numProbes) / numKeyLookups;
+ }
+
+ public long getNumHashCollisions() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return numHashCollisions;
+ }
+
+ /**
+ * Grows the size of the hash table and re-hashes everything.
+ */
+ private void growAndRehash() {
+ long resizeStartTime = -1;
+ if (enablePerfMetrics) {
+ resizeStartTime = System.nanoTime();
+ }
+ // Store references to the old data structures to be used when we re-hash
+ final LongArray oldLongArray = longArray;
+ final BitSet oldBitSet = bitset;
+ final int oldCapacity = (int) oldBitSet.capacity();
+
+ // Allocate the new data structures
+ allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity)));
+
+ // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
+ for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
+ final long keyPointer = oldLongArray.get(pos * 2);
+ final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
+ int newPos = hashcode & mask;
+ int step = 1;
+ boolean keepGoing = true;
+
+ // There is no need to check for key equality when re-inserting during a rehash, so this
+ // loop has one fewer branch than the similar probing loop in lookup().
+ while (keepGoing) {
+ if (!bitset.isSet(newPos)) {
+ bitset.set(newPos);
+ longArray.set(newPos * 2, keyPointer);
+ longArray.set(newPos * 2 + 1, hashcode);
+ keepGoing = false;
+ } else {
+ newPos = (newPos + step) & mask;
+ step++;
+ }
+ }
+ }
+
+ // Deallocate the old data structures.
+ memoryManager.free(oldLongArray.memoryBlock());
+ memoryManager.free(oldBitSet.memoryBlock());
+ if (enablePerfMetrics) {
+ timeSpentResizingNs += System.nanoTime() - resizeStartTime;
+ }
+ }
+
+ /** Returns the next number greater than or equal to {@code num} that is a power of 2. */
+ private static long nextPowerOf2(long num) {
+ final long highBit = Long.highestOneBit(num);
+ return (highBit == num) ? num : highBit << 1;
+ }
+}
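End-to-end sketch (not part of the patch) of the lookup/putNewKey protocol described in the javadoc above, using the heap allocator and word-aligned `long[]` regions as opaque key and value bytes. The class name is hypothetical:

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.unsafe.map.BytesToBytesMap;
    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public class BytesToBytesMapSketch {
      public static void main(String[] args) {
        TaskMemoryManager memoryManager =
          new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
        BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
        // Keys and values are opaque, word-aligned byte regions; a single long is 8 bytes.
        long[] key = {1L};
        long[] value = {100L};
        BytesToBytesMap.Location loc =
          map.lookup(key, PlatformDependent.LONG_ARRAY_OFFSET, 8);
        if (!loc.isDefined()) {
          loc.putNewKey(
            key, PlatformDependent.LONG_ARRAY_OFFSET, 8,
            value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
        }
        // Read the stored value back through the returned Location.
        long stored = PlatformDependent.UNSAFE.getLong(
          loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset());
        System.out.println(stored); // 100
        map.free();
        memoryManager.cleanUpAllAllocatedMemory();
      }
    }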
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
new file mode 100644
index 0000000000..7c321baffe
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+/**
+ * Interface that defines how we can grow the size of a hash map when it is over a threshold.
+ */
+public interface HashMapGrowthStrategy {
+
+ int nextCapacity(int currentCapacity);
+
+ /**
+ * Double the size of the hash map every time.
+ */
+ HashMapGrowthStrategy DOUBLING = new Doubling();
+
+ class Doubling implements HashMapGrowthStrategy {
+ @Override
+ public int nextCapacity(int currentCapacity) {
+ return currentCapacity * 2;
+ }
+ }
+
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
new file mode 100644
index 0000000000..62c29c8cc1
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+/**
+ * Manages memory for an executor. Individual operators / tasks allocate memory through
+ * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager.
+ */
+public class ExecutorMemoryManager {
+
+ /**
+ * Allocator, exposed for enabling untracked allocations of temporary data structures.
+ */
+ public final MemoryAllocator allocator;
+
+ /**
+ * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe.
+ */
+ final boolean inHeap;
+
+ /**
+ * Construct a new ExecutorMemoryManager.
+ *
+ * @param allocator the allocator that will be used
+ */
+ public ExecutorMemoryManager(MemoryAllocator allocator) {
+ this.inHeap = allocator instanceof HeapMemoryAllocator;
+ this.allocator = allocator;
+ }
+
+ /**
+ * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+ * to be zeroed out (call `zero()` on the result if this is necessary).
+ */
+ MemoryBlock allocate(long size) throws OutOfMemoryError {
+ return allocator.allocate(size);
+ }
+
+ void free(MemoryBlock memory) {
+ allocator.free(memory);
+ }
+
+}
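As a rough usage sketch of the relationship described in the class comment, an executor holds a single ExecutorMemoryManager and hands it to per-task TaskMemoryManagers (the 4096-byte page size below is an arbitrary example value):

    // One executor-wide manager, shared by task-level managers.
    ExecutorMemoryManager executorManager = new ExecutorMemoryManager(MemoryAllocator.HEAP);
    TaskMemoryManager taskManager = new TaskMemoryManager(executorManager);
    MemoryBlock page = taskManager.allocatePage(4096);
    // ... task code uses the page ...
    taskManager.freePage(page);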
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
new file mode 100644
index 0000000000..bbe83d36cf
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+/**
+ * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
+ */
+public class HeapMemoryAllocator implements MemoryAllocator {
+
+ @Override
+ public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ long[] array = new long[(int) (size / 8)];
+ return MemoryBlock.fromLongArray(array);
+ }
+
+ @Override
+ public void free(MemoryBlock memory) {
+ // Do nothing
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
new file mode 100644
index 0000000000..5192f68c86
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+public interface MemoryAllocator {
+
+ /**
+ * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+ * to be zeroed out (call `zero()` on the result if this is necessary).
+ */
+ MemoryBlock allocate(long size) throws OutOfMemoryError;
+
+ void free(MemoryBlock memory);
+
+ MemoryAllocator UNSAFE = new UnsafeMemoryAllocator();
+
+ MemoryAllocator HEAP = new HeapMemoryAllocator();
+}
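A small sketch of how the two built-in allocators are used; per the contract above, an allocated block is not guaranteed to be zeroed, so callers that need cleared memory invoke zero() themselves (the 64-byte size is an arbitrary example value):

    // Allocate, clear, and release a block through the interface.
    MemoryBlock onHeap = MemoryAllocator.HEAP.allocate(64).zero();
    MemoryAllocator.HEAP.free(onHeap);

    MemoryBlock offHeap = MemoryAllocator.UNSAFE.allocate(64).zero();
    MemoryAllocator.UNSAFE.free(offHeap);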
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
new file mode 100644
index 0000000000..0beb743e56
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import javax.annotation.Nullable;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A contiguous block of memory, starting at a {@link MemoryLocation} with a fixed size.
+ */
+public class MemoryBlock extends MemoryLocation {
+
+ private final long length;
+
+ /**
+   * Optional page number; used when this MemoryBlock represents a page allocated by a
+   * TaskMemoryManager. This field is package-private and is modified by TaskMemoryManager.
+ */
+ int pageNumber = -1;
+
+ MemoryBlock(@Nullable Object obj, long offset, long length) {
+ super(obj, offset);
+ this.length = length;
+ }
+
+ /**
+ * Returns the size of the memory block.
+ */
+ public long size() {
+ return length;
+ }
+
+ /**
+ * Clear the contents of this memory block. Returns `this` to facilitate chaining.
+ */
+ public MemoryBlock zero() {
+ PlatformDependent.UNSAFE.setMemory(obj, offset, length, (byte) 0);
+ return this;
+ }
+
+ /**
+ * Creates a memory block pointing to the memory used by the long array.
+ */
+ public static MemoryBlock fromLongArray(final long[] array) {
+ return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8);
+ }
+}
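For example, wrapping an existing long[] gives a block whose size is the array length in bytes (a sketch with made-up values):

    long[] backing = new long[16];
    MemoryBlock block = MemoryBlock.fromLongArray(backing);
    // 16 longs * 8 bytes per long = 128 bytes.
    assert block.size() == 128;
    block.zero();  // clears the backing array via Unsafe.setMemory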
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
new file mode 100644
index 0000000000..74ebc87dc9
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import javax.annotation.Nullable;
+
+/**
+ * A memory location. Tracked either by a memory address (with off-heap allocation),
+ * or by an offset from a JVM object (in-heap allocation).
+ */
+public class MemoryLocation {
+
+ @Nullable
+ Object obj;
+
+ long offset;
+
+ public MemoryLocation(@Nullable Object obj, long offset) {
+ this.obj = obj;
+ this.offset = offset;
+ }
+
+ public MemoryLocation() {
+ this(null, 0);
+ }
+
+ public void setObjAndOffset(Object newObj, long newOffset) {
+ this.obj = newObj;
+ this.offset = newOffset;
+ }
+
+ public final Object getBaseObject() {
+ return obj;
+ }
+
+ public final long getBaseOffset() {
+ return offset;
+ }
+}
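A brief sketch of the two addressing modes this class unifies (the off-heap address below is a made-up constant for illustration, not a real allocation):

    // In-heap: a base object plus an offset within it.
    long[] data = new long[4];
    MemoryLocation inHeap = new MemoryLocation(data, PlatformDependent.LONG_ARRAY_OFFSET);

    // Off-heap: no base object, just an absolute 64-bit address.
    MemoryLocation offHeap = new MemoryLocation(null, 0x10000L);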
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
new file mode 100644
index 0000000000..9224988e6a
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import java.util.*;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Manages the memory allocated by an individual task.
+ * <p>
+ * Most of the complexity in this class deals with encoding both on-heap and off-heap addresses
+ * into 64-bit longs.
+ * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is
+ * addressed by the combination of a base Object reference and a 64-bit offset within that object.
+ * This is a problem when we want to store pointers to data structures inside of other structures,
+ * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits
+ * to address memory, we can't just store the address of the base object since it's not guaranteed
+ * to remain stable as the heap gets reorganized due to GC.
+ * <p>
+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap
+ * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to
+ * store a "page number" and the lower 51 bits to store an offset within this page. These page
+ * numbers are used to index into a "page table" array inside of the MemoryManager in order to
+ * retrieve the base object.
+ * <p>
+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
+ * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
+ * approximately 140 terabytes of memory.
+ */
+public final class TaskMemoryManager {
+
+ private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
+
+ /**
+ * The number of entries in the page table.
+ */
+ private static final int PAGE_TABLE_SIZE = 1 << 13;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+  /** Bit mask for the upper 13 bits of a long. */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Similar to an operating system's page table, this array maps page numbers into base object
+ * pointers, allowing us to translate between the hashtable's internal 64-bit address
+ * representation and the baseObject+offset representation which we use to support both in- and
+ * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
+ * When using an in-heap allocator, the entries in this map will point to pages' base objects.
+ * Entries are added to this map as new data pages are allocated.
+ */
+ private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+
+ /**
+   * Bitmap for tracking which page numbers are in use (a clear bit marks a free page).
+ */
+ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+ /**
+ * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean
+ * up leaked memory.
+ */
+ private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>();
+
+ private final ExecutorMemoryManager executorMemoryManager;
+
+ /**
+ * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
+ * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
+ * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
+ */
+ private final boolean inHeap;
+
+ /**
+   * Construct a new TaskMemoryManager.
+ */
+ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
+ this.inHeap = executorMemoryManager.inHeap;
+ this.executorMemoryManager = executorMemoryManager;
+ }
+
+ /**
+ * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
+ * intended for allocating large blocks of memory that will be shared between operators.
+ */
+ public MemoryBlock allocatePage(long size) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Allocating {} byte page", size);
+ }
+ if (size >= (1L << 51)) {
+ throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+ }
+
+ final int pageNumber;
+ synchronized (this) {
+ pageNumber = allocatedPages.nextClearBit(0);
+ if (pageNumber >= PAGE_TABLE_SIZE) {
+ throw new IllegalStateException(
+ "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+ }
+ allocatedPages.set(pageNumber);
+ }
+ final MemoryBlock page = executorMemoryManager.allocate(size);
+ page.pageNumber = pageNumber;
+ pageTable[pageNumber] = page;
+ if (logger.isDebugEnabled()) {
+ logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+ }
+ return page;
+ }
+
+ /**
+ * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+ */
+ public void freePage(MemoryBlock page) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
+ }
+ assert (page.pageNumber != -1) :
+ "Called freePage() on memory that wasn't allocated with allocatePage()";
+ executorMemoryManager.free(page);
+ synchronized (this) {
+ allocatedPages.clear(page.pageNumber);
+ }
+ pageTable[page.pageNumber] = null;
+ if (logger.isDebugEnabled()) {
+ logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+ }
+ }
+
+ /**
+ * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+ * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended
+ * to be used for allocating operators' internal data structures. For data pages that you want to
+ * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since
+ * that will enable intra-memory pointers (see
+ * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's
+ * top-level Javadoc for more details).
+ */
+ public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ final MemoryBlock memory = executorMemoryManager.allocate(size);
+ allocatedNonPageMemory.add(memory);
+ return memory;
+ }
+
+ /**
+ * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
+ */
+ public void free(MemoryBlock memory) {
+ assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
+ executorMemoryManager.free(memory);
+ final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
+ assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
+ }
+
+ /**
+ * Given a memory page and offset within that page, encode this address into a 64-bit long.
+ * This address will remain valid as long as the corresponding page has not been freed.
+ */
+ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
+ if (inHeap) {
+ assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ } else {
+ return offsetInPage;
+ }
+ }
+
+ /**
+ * Get the page associated with an address encoded by
+ * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public Object getPage(long pagePlusOffsetAddress) {
+ if (inHeap) {
+ final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ final Object page = pageTable[pageNumber].getBaseObject();
+ assert (page != null);
+ return page;
+ } else {
+ return null;
+ }
+ }
+
+ /**
+ * Get the offset associated with an address encoded by
+ * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public long getOffsetInPage(long pagePlusOffsetAddress) {
+ if (inHeap) {
+ return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+ } else {
+ return pagePlusOffsetAddress;
+ }
+ }
+
+ /**
+ * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return
+ * value can be used to detect memory leaks.
+ */
+ public long cleanUpAllAllocatedMemory() {
+ long freedBytes = 0;
+ for (MemoryBlock page : pageTable) {
+ if (page != null) {
+ freedBytes += page.size();
+ freePage(page);
+ }
+ }
+ final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
+ while (iter.hasNext()) {
+ final MemoryBlock memory = iter.next();
+ freedBytes += memory.size();
+ // We don't call free() here because that calls Set.remove, which would lead to a
+ // ConcurrentModificationException here.
+ executorMemoryManager.free(memory);
+ iter.remove();
+ }
+ return freedBytes;
+ }
+}
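A rough round-trip sketch of the address encoding described in the class comment (13-bit page number in the upper bits, 51-bit offset in the lower bits), using the on-heap allocator; the 4096-byte page and 256-byte offset are arbitrary example values:

    TaskMemoryManager manager =
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
    MemoryBlock page = manager.allocatePage(4096);

    // Pack (pageNumber, offsetInPage) into a single long...
    long encoded = manager.encodePageNumberAndOffset(page, 256);
    // ...and recover both halves later.
    assert manager.getOffsetInPage(encoded) == 256;
    assert manager.getPage(encoded) == page.getBaseObject();

    manager.freePage(page);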
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
new file mode 100644
index 0000000000..15898771fe
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory.
+ */
+public class UnsafeMemoryAllocator implements MemoryAllocator {
+
+ @Override
+ public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ long address = PlatformDependent.UNSAFE.allocateMemory(size);
+ return new MemoryBlock(null, address, size);
+ }
+
+ @Override
+ public void free(MemoryBlock memory) {
+ assert (memory.obj == null) :
+ "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
+ PlatformDependent.UNSAFE.freeMemory(memory.offset);
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
new file mode 100644
index 0000000000..5974cf91ff
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+public class LongArraySuite {
+
+ @Test
+ public void basicTest() {
+ long[] bytes = new long[2];
+ LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes));
+ arr.set(0, 1L);
+ arr.set(1, 2L);
+ arr.set(1, 3L);
+ Assert.assertEquals(2, arr.size());
+ Assert.assertEquals(1L, arr.get(0));
+ Assert.assertEquals(3L, arr.get(1));
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
new file mode 100644
index 0000000000..4bf132fd40
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import junit.framework.Assert;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+public class BitSetSuite {
+
+ private static BitSet createBitSet(int capacity) {
+ assert capacity % 64 == 0;
+ return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]).zero());
+ }
+
+ @Test
+ public void basicOps() {
+ BitSet bs = createBitSet(64);
+ Assert.assertEquals(64, bs.capacity());
+
+ // Make sure the bit set starts empty.
+ for (int i = 0; i < bs.capacity(); i++) {
+ Assert.assertFalse(bs.isSet(i));
+ }
+
+ // Set every bit and check it.
+ for (int i = 0; i < bs.capacity(); i++) {
+ bs.set(i);
+ Assert.assertTrue(bs.isSet(i));
+ }
+
+ // Unset every bit and check it.
+ for (int i = 0; i < bs.capacity(); i++) {
+ Assert.assertTrue(bs.isSet(i));
+ bs.unset(i);
+ Assert.assertFalse(bs.isSet(i));
+ }
+ }
+
+ @Test
+ public void traversal() {
+ BitSet bs = createBitSet(256);
+
+ Assert.assertEquals(-1, bs.nextSetBit(0));
+ Assert.assertEquals(-1, bs.nextSetBit(10));
+ Assert.assertEquals(-1, bs.nextSetBit(64));
+
+ bs.set(10);
+ Assert.assertEquals(10, bs.nextSetBit(0));
+ Assert.assertEquals(10, bs.nextSetBit(1));
+ Assert.assertEquals(10, bs.nextSetBit(10));
+ Assert.assertEquals(-1, bs.nextSetBit(11));
+
+ bs.set(11);
+ Assert.assertEquals(10, bs.nextSetBit(10));
+ Assert.assertEquals(11, bs.nextSetBit(11));
+
+ // Skip a whole word and find it
+ bs.set(190);
+ Assert.assertEquals(190, bs.nextSetBit(12));
+
+ Assert.assertEquals(-1, bs.nextSetBit(191));
+ Assert.assertEquals(-1, bs.nextSetBit(256));
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
new file mode 100644
index 0000000000..3b91758352
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.hash;
+
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
+
+import junit.framework.Assert;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.junit.Test;
+
+/**
+ * Test file based on Guava's Murmur3Hash32Test.
+ */
+public class Murmur3_x86_32Suite {
+
+ private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0);
+
+ @Test
+ public void testKnownIntegerInputs() {
+ Assert.assertEquals(593689054, hasher.hashInt(0));
+ Assert.assertEquals(-189366624, hasher.hashInt(-42));
+ Assert.assertEquals(-1134849565, hasher.hashInt(42));
+ Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE));
+ Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE));
+ }
+
+ @Test
+ public void testKnownLongInputs() {
+ Assert.assertEquals(1669671676, hasher.hashLong(0L));
+ Assert.assertEquals(-846261623, hasher.hashLong(-42L));
+ Assert.assertEquals(1871679806, hasher.hashLong(42L));
+ Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE));
+ Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
+ }
+
+ @Test
+ public void randomizedStressTest() {
+ int size = 65536;
+ Random rand = new Random();
+
+ // A set used to track collision rate.
+ Set<Integer> hashcodes = new HashSet<Integer>();
+ for (int i = 0; i < size; i++) {
+ int vint = rand.nextInt();
+ long lint = rand.nextLong();
+ Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint));
+ Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint));
+
+ hashcodes.add(hasher.hashLong(lint));
+ }
+
+ // A very loose bound.
+ Assert.assertTrue(hashcodes.size() > size * 0.95);
+ }
+
+ @Test
+ public void randomizedStressTestBytes() {
+ int size = 65536;
+ Random rand = new Random();
+
+ // A set used to track collision rate.
+ Set<Integer> hashcodes = new HashSet<Integer>();
+ for (int i = 0; i < size; i++) {
+ int byteArrSize = rand.nextInt(100) * 8;
+ byte[] bytes = new byte[byteArrSize];
+ rand.nextBytes(bytes);
+
+ Assert.assertEquals(
+ hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
+ hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+
+ hashcodes.add(hasher.hashUnsafeWords(
+ bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+ }
+
+ // A very loose bound.
+ Assert.assertTrue(hashcodes.size() > size * 0.95);
+ }
+
+ @Test
+ public void randomizedStressTestPaddedStrings() {
+ int size = 64000;
+ // A set used to track collision rate.
+ Set<Integer> hashcodes = new HashSet<Integer>();
+ for (int i = 0; i < size; i++) {
+ int byteArrSize = 8;
+ byte[] strBytes = ("" + i).getBytes();
+ byte[] paddedBytes = new byte[byteArrSize];
+ System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+
+ Assert.assertEquals(
+ hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
+ hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+
+ hashcodes.add(hasher.hashUnsafeWords(
+ paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+ }
+
+ // A very loose bound.
+ Assert.assertTrue(hashcodes.size() > size * 0.95);
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
new file mode 100644
index 0000000000..9038cf567f
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Exception;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.PlatformDependent;
+import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public abstract class AbstractBytesToBytesMapSuite {
+
+ private final Random rand = new Random(42);
+
+ private TaskMemoryManager memoryManager;
+
+ @Before
+ public void setup() {
+ memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+ }
+
+ @After
+ public void tearDown() {
+ if (memoryManager != null) {
+ memoryManager.cleanUpAllAllocatedMemory();
+ memoryManager = null;
+ }
+ }
+
+ protected abstract MemoryAllocator getMemoryAllocator();
+
+ private static byte[] getByteArray(MemoryLocation loc, int size) {
+ final byte[] arr = new byte[size];
+ PlatformDependent.UNSAFE.copyMemory(
+ loc.getBaseObject(),
+ loc.getBaseOffset(),
+ arr,
+ BYTE_ARRAY_OFFSET,
+ size
+ );
+ return arr;
+ }
+
+ private byte[] getRandomByteArray(int numWords) {
+ Assert.assertTrue(numWords > 0);
+ final int lengthInBytes = numWords * 8;
+ final byte[] bytes = new byte[lengthInBytes];
+ rand.nextBytes(bytes);
+ return bytes;
+ }
+
+ /**
+ * Fast equality checking for byte arrays, since these comparisons are a bottleneck
+ * in our stress tests.
+ */
+ private static boolean arrayEquals(
+ byte[] expected,
+ MemoryLocation actualAddr,
+ long actualLengthBytes) {
+ return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals(
+ expected,
+ BYTE_ARRAY_OFFSET,
+ actualAddr.getBaseObject(),
+ actualAddr.getBaseOffset(),
+ expected.length
+ );
+ }
+
+ @Test
+ public void emptyMap() {
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+ try {
+ Assert.assertEquals(0, map.size());
+ final int keyLengthInWords = 10;
+ final int keyLengthInBytes = keyLengthInWords * 8;
+ final byte[] key = getRandomByteArray(keyLengthInWords);
+ Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void setAndRetrieveAKey() {
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+ final int recordLengthWords = 10;
+ final int recordLengthBytes = recordLengthWords * 8;
+ final byte[] keyData = getRandomByteArray(recordLengthWords);
+ final byte[] valueData = getRandomByteArray(recordLengthWords);
+ try {
+ final BytesToBytesMap.Location loc =
+ map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
+ Assert.assertFalse(loc.isDefined());
+ loc.putNewKey(
+ keyData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes,
+ valueData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes
+ );
+ // After storing the key and value, the other location methods should return results that
+ // reflect the result of this store without us having to call lookup() again on the same key.
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+ Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+ Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+ // After calling lookup() the location should still point to the correct data.
+ Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+ Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+ Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+ try {
+ loc.putNewKey(
+ keyData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes,
+ valueData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes
+ );
+ Assert.fail("Should not be able to set a new value for a key");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void iteratorTest() throws Exception {
+ final int size = 128;
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
+ try {
+ for (long i = 0; i < size; i++) {
+ final long[] value = new long[] { i };
+ final BytesToBytesMap.Location loc =
+ map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+ Assert.assertFalse(loc.isDefined());
+ loc.putNewKey(
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8
+ );
+ }
+ final java.util.BitSet valuesSeen = new java.util.BitSet(size);
+ final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ Assert.assertTrue(loc.isDefined());
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ final long key = PlatformDependent.UNSAFE.getLong(
+ keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+ final long value = PlatformDependent.UNSAFE.getLong(
+ valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+ Assert.assertEquals(key, value);
+ valuesSeen.set((int) value);
+ }
+ Assert.assertEquals(size, valuesSeen.cardinality());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void randomizedStressTest() {
+ final int size = 65536;
+ // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+ // into ByteBuffers in order to use them as keys here.
+ final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+ final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size);
+
+ try {
+ // Fill the map to 90% full so that we can trigger probing
+ for (int i = 0; i < size * 0.9; i++) {
+ final byte[] key = getRandomByteArray(rand.nextInt(256) + 1);
+ final byte[] value = getRandomByteArray(rand.nextInt(512) + 1);
+ if (!expected.containsKey(ByteBuffer.wrap(key))) {
+ expected.put(ByteBuffer.wrap(key), value);
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length
+ );
+ Assert.assertFalse(loc.isDefined());
+ loc.putNewKey(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length,
+ value,
+ BYTE_ARRAY_OFFSET,
+ value.length
+ );
+ // After calling putNewKey, the following should be true, even before calling
+ // lookup():
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(key.length, loc.getKeyLength());
+ Assert.assertEquals(value.length, loc.getValueLength());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ }
+ }
+
+ for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+ final byte[] key = entry.getKey().array();
+ final byte[] value = entry.getValue();
+ final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ }
+ } finally {
+ map.free();
+ }
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
new file mode 100644
index 0000000000..5a10de49f5
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.UNSAFE;
+ }
+
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
new file mode 100644
index 0000000000..12cc9b25d9
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.HEAP;
+ }
+
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
new file mode 100644
index 0000000000..932882f1ca
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TaskMemoryManagerSuite {
+
+ @Test
+ public void leakedNonPageMemoryIsDetected() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ manager.allocate(1024); // leak memory
+ Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory());
+ }
+
+ @Test
+ public void leakedPageMemoryIsDetected() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ manager.allocatePage(4096); // leak memory
+ Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
+ }
+
+}