diff options
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 14 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala | 47 |
2 files changed, 46 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 362bf45d03..0f6292db62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -96,10 +96,7 @@ class SparkSession private( */ @transient private[sql] lazy val sharedState: SharedState = { - existingSharedState.getOrElse( - SparkSession.reflect[SharedState, SparkContext]( - SparkSession.sharedStateClassName(sparkContext.conf), - sparkContext)) + existingSharedState.getOrElse(new SharedState(sparkContext)) } /** @@ -913,16 +910,8 @@ object SparkSession { /** Reference to the root SparkSession. */ private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState" private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" - private def sharedStateClassName(conf: SparkConf): String = { - conf.get(CATALOG_IMPLEMENTATION) match { - case "hive" => HIVE_SHARED_STATE_CLASS_NAME - case "in-memory" => classOf[SharedState].getCanonicalName - } - } - private def sessionStateClassName(conf: SparkConf): String = { conf.get(CATALOG_IMPLEMENTATION) match { case "hive" => HIVE_SESSION_STATE_CLASS_NAME @@ -953,7 +942,6 @@ object SparkSession { private[spark] def hiveClassesArePresent: Boolean = { try { Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME) - Utils.classForName(HIVE_SHARED_STATE_CLASS_NAME) Utils.classForName("org.apache.hadoop.hive.conf.HiveConf") true } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 54aee5e02b..6387f01506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,7 +17,13 @@ package org.apache.spark.sql.internal -import org.apache.spark.SparkContext +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} @@ -51,7 +57,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = new InMemoryCatalog(sparkContext.hadoopConfiguration) + lazy val externalCatalog: ExternalCatalog = + SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + SharedState.externalCatalogClassName(sparkContext.conf), + sparkContext.conf, + sparkContext.hadoopConfiguration) /** * A classloader used to load all user-added jar. @@ -98,6 +108,39 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } +object SharedState { + + private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" + + private def externalCatalogClassName(conf: SparkConf): String = { + conf.get(CATALOG_IMPLEMENTATION) match { + case "hive" => HIVE_EXTERNAL_CATALOG_CLASS_NAME + case "in-memory" => classOf[InMemoryCatalog].getCanonicalName + } + } + + /** + * Helper method to create an instance of [[T]] using a single-arg constructor that + * accepts an [[Arg1]] and an [[Arg2]]. + */ + private def reflect[T, Arg1 <: AnyRef, Arg2 <: AnyRef]( + className: String, + ctorArg1: Arg1, + ctorArg2: Arg2)( + implicit ctorArgTag1: ClassTag[Arg1], + ctorArgTag2: ClassTag[Arg2]): T = { + try { + val clazz = Utils.classForName(className) + val ctor = clazz.getDeclaredConstructor(ctorArgTag1.runtimeClass, ctorArgTag2.runtimeClass) + val args = Array[AnyRef](ctorArg1, ctorArg2) + ctor.newInstance(args: _*).asInstanceOf[T] + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Error while instantiating '$className':", e) + } + } +} + /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. |