diff options
author | Jakob Odersky <jakob@odersky.com> | 2017-03-29 11:55:51 -0700 |
---|---|---|
committer | Jakob Odersky <jakob@odersky.com> | 2017-04-24 14:09:49 -0700 |
commit | a043265e009c32852a0e6d37349921c9c3a5bf28 (patch) | |
tree | 65c4373df4f5cdadc2c5f773ec86f47d2f76a5ef | |
parent | a3860c59deebf996f0c32bcc0d15b2903216e732 (diff) | |
download | spark-a043265e009c32852a0e6d37349921c9c3a5bf28.tar.gz spark-a043265e009c32852a0e6d37349921c9c3a5bf28.tar.bz2 spark-a043265e009c32852a0e6d37349921c9c3a5bf28.zip |
Add implicit function conversions to task listeners
3 files changed, 21 insertions, 36 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ff9d8dd163..dcae4a34c4 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; /** @@ -161,7 +162,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addJavaFriendlyTaskCompletionListener( + taskContext.addTaskCompletionListener( new TaskCompletionListener() { @Override public void onTaskCompletion(TaskContext context) { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 742828a227..1e9cb0a432 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -103,21 +103,6 @@ abstract class TaskContext extends Serializable { @deprecated("Local execution was removed, so this always returns false", "2.0.0") def isRunningLocally(): Boolean - // TODO(josh): this used to be an overload of addTaskCompletionListener(), but the overload - // became ambiguous under Scala 2.12. For now, I'm renaming this in order to get the code to - // compile, but we need to figure out a long-term solution which maintains at least source - // compatibility (and probably binary compatibility) for Java callers. - /** - * Adds a (Java friendly) listener to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. Adding a listener - * to an already completed task will result in that listener being called immediately. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - * - * Exceptions thrown by the listener will result in failure of the task. - */ - def addJavaFriendlyTaskCompletionListener(listener: TaskCompletionListener): TaskContext - /** * Adds a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. Adding a listener @@ -127,31 +112,13 @@ abstract class TaskContext extends Serializable { * * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { - addJavaFriendlyTaskCompletionListener(new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - }) - } - - // TODO(josh): this used to be an overload of addTaskFailureListener(), but the overload - // became ambiguous under Scala 2.12. For now, I'm renaming this in order to get the code to - // compile, but we need to figure out a long-term solution which maintains at least source - // compatibility (and probably binary compatibility) for Java callers. - /** - * Adds a listener to be executed on task failure. Adding a listener to an already failed task - * will result in that listener being called immediately. - */ - def addJavaFriendlyTaskFailureListener(listener: TaskFailureListener): TaskContext + def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext /** * Adds a listener to be executed on task failure. Adding a listener to an already failed task * will result in that listener being called immediately. */ - def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { - addJavaFriendlyTaskFailureListener(new TaskFailureListener { - override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error) - }) - } + def addTaskFailureListener(listener: TaskFailureListener): TaskContext /** * The ID of the stage that this task belong to. diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e88ab..fb3852a636 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.util.EventListener +import scala.language.implicitConversions + import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi @@ -32,6 +34,13 @@ trait TaskCompletionListener extends EventListener { def onTaskCompletion(context: TaskContext): Unit } +object TaskCompletionListener { + implicit def functionToTaskCompletionListener(f: TaskContext => Unit): TaskCompletionListener = + new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } +} + /** * :: DeveloperApi :: @@ -44,6 +53,14 @@ trait TaskFailureListener extends EventListener { def onTaskFailure(context: TaskContext, error: Throwable): Unit } +object TaskFailureListener { + implicit def functionToTaskFailureListener( + f: (TaskContext, Throwable) => Unit): TaskFailureListener = + new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error) + } +} + /** * Exception thrown when there is an exception in executing the callback in TaskCompletionListener. |