aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala47
2 files changed, 46 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 362bf45d03..0f6292db62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -96,10 +96,7 @@ class SparkSession private(
*/
@transient
private[sql] lazy val sharedState: SharedState = {
- existingSharedState.getOrElse(
- SparkSession.reflect[SharedState, SparkContext](
- SparkSession.sharedStateClassName(sparkContext.conf),
- sparkContext))
+ existingSharedState.getOrElse(new SharedState(sparkContext))
}
/**
@@ -913,16 +910,8 @@ object SparkSession {
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]
- private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState"
private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
- private def sharedStateClassName(conf: SparkConf): String = {
- conf.get(CATALOG_IMPLEMENTATION) match {
- case "hive" => HIVE_SHARED_STATE_CLASS_NAME
- case "in-memory" => classOf[SharedState].getCanonicalName
- }
- }
-
private def sessionStateClassName(conf: SparkConf): String = {
conf.get(CATALOG_IMPLEMENTATION) match {
case "hive" => HIVE_SESSION_STATE_CLASS_NAME
@@ -953,7 +942,6 @@ object SparkSession {
private[spark] def hiveClassesArePresent: Boolean = {
try {
Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME)
- Utils.classForName(HIVE_SHARED_STATE_CLASS_NAME)
Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
true
} catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 54aee5e02b..6387f01506 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -17,7 +17,13 @@
package org.apache.spark.sql.internal
-import org.apache.spark.SparkContext
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.internal.config._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
@@ -51,7 +57,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
/**
* A catalog that interacts with external systems.
*/
- lazy val externalCatalog: ExternalCatalog = new InMemoryCatalog(sparkContext.hadoopConfiguration)
+ lazy val externalCatalog: ExternalCatalog =
+ SharedState.reflect[ExternalCatalog, SparkConf, Configuration](
+ SharedState.externalCatalogClassName(sparkContext.conf),
+ sparkContext.conf,
+ sparkContext.hadoopConfiguration)
/**
* A classloader used to load all user-added jar.
@@ -98,6 +108,39 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
}
}
+object SharedState {
+
+ private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog"
+
+ private def externalCatalogClassName(conf: SparkConf): String = {
+ conf.get(CATALOG_IMPLEMENTATION) match {
+ case "hive" => HIVE_EXTERNAL_CATALOG_CLASS_NAME
+ case "in-memory" => classOf[InMemoryCatalog].getCanonicalName
+ }
+ }
+
+ /**
+ * Helper method to create an instance of [[T]] using a single-arg constructor that
+ * accepts an [[Arg1]] and an [[Arg2]].
+ */
+ private def reflect[T, Arg1 <: AnyRef, Arg2 <: AnyRef](
+ className: String,
+ ctorArg1: Arg1,
+ ctorArg2: Arg2)(
+ implicit ctorArgTag1: ClassTag[Arg1],
+ ctorArgTag2: ClassTag[Arg2]): T = {
+ try {
+ val clazz = Utils.classForName(className)
+ val ctor = clazz.getDeclaredConstructor(ctorArgTag1.runtimeClass, ctorArgTag2.runtimeClass)
+ val args = Array[AnyRef](ctorArg1, ctorArg2)
+ ctor.newInstance(args: _*).asInstanceOf[T]
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
+ }
+ }
+}
+
/**
* URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader.