From 508d59ea5e2865a5bfc6d380527b1ab91e0fddd5 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 25 Sep 2017 15:22:20 -0700 Subject: Pull in slickmdc --- .../driver/core/database/MdcAsyncExecutor.scala | 46 ++++++++++++++++++++++ .../driver/core/database/MdcExecutionContext.scala | 28 +++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala create mode 100644 src/main/scala/xyz/driver/core/database/MdcExecutionContext.scala (limited to 'src/main/scala/xyz/driver/core/database') diff --git a/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala b/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala new file mode 100644 index 0000000..ac8456e --- /dev/null +++ b/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala @@ -0,0 +1,46 @@ +package xyz.driver.core.database + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent._ +import com.typesafe.scalalogging.StrictLogging +import slick.util.AsyncExecutor + +/** Taken from the original Slick AsyncExecutor and simplified + * @see https://github.com/slick/slick/blob/3.1/slick/src/main/scala/slick/util/AsyncExecutor.scala + */ +object MdcAsyncExecutor extends StrictLogging { + + /** Create an AsyncExecutor with a fixed-size thread pool. + * + * @param name The name for the thread pool. + * @param numThreads The number of threads in the pool. + */ + def apply(name: String, numThreads: Int): AsyncExecutor = { + new AsyncExecutor { + val tf = new DaemonThreadFactory(name + "-") + + lazy val executionContext = { + new MdcExecutionContext(ExecutionContext.fromExecutor(Executors.newFixedThreadPool(numThreads, tf))) + } + + def close(): Unit = {} + } + } + + def default(name: String = "AsyncExecutor.default"): AsyncExecutor = apply(name, 20) + + private class DaemonThreadFactory(namePrefix: String) extends ThreadFactory { + private[this] val group = + Option(System.getSecurityManager).fold(Thread.currentThread.getThreadGroup)(_.getThreadGroup) + private[this] val threadNumber = new AtomicInteger(1) + + def newThread(r: Runnable): Thread = { + val t = new Thread(group, r, namePrefix + threadNumber.getAndIncrement, 0) + if (!t.isDaemon) t.setDaemon(true) + if (t.getPriority != Thread.NORM_PRIORITY) t.setPriority(Thread.NORM_PRIORITY) + t + } + } +} diff --git a/src/main/scala/xyz/driver/core/database/MdcExecutionContext.scala b/src/main/scala/xyz/driver/core/database/MdcExecutionContext.scala new file mode 100644 index 0000000..f08f16c --- /dev/null +++ b/src/main/scala/xyz/driver/core/database/MdcExecutionContext.scala @@ -0,0 +1,28 @@ +package xyz.driver.core.database + +import org.slf4j.MDC +import scala.concurrent.ExecutionContext + +/** + * Execution context proxy for propagating SLF4J diagnostic context from caller thread to execution thread. + */ +class MdcExecutionContext(executionContext: ExecutionContext) extends ExecutionContext { + override def execute(runnable: Runnable): Unit = { + val callerMdc = MDC.getCopyOfContextMap + executionContext.execute(new Runnable { + def run(): Unit = { + // copy caller thread diagnostic context to execution thread + // scalastyle:off + if (callerMdc != null) MDC.setContextMap(callerMdc) + try { + runnable.run() + } finally { + // the thread might be reused, so we clean up for the next use + MDC.clear() + } + } + }) + } + + override def reportFailure(cause: Throwable): Unit = executionContext.reportFailure(cause) +} -- cgit v1.2.3