aboutsummaryrefslogtreecommitdiff
path: root/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
diff options
context:
space:
mode:
Diffstat (limited to 'repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala')
-rw-r--r--  repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala  24
1 file changed, 10 insertions(+), 14 deletions(-)
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index bd853f1522..8e381ff6ae 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -23,8 +23,8 @@ import scala.tools.nsc.GenericRunnerSettings
import org.apache.spark._
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
-import org.apache.spark.sql.SQLContext
object Main extends Logging {
@@ -35,7 +35,7 @@ object Main extends Logging {
val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl")
var sparkContext: SparkContext = _
- var sqlContext: SQLContext = _
+ var sparkSession: SparkSession = _
// this is a public var because tests reset it.
var interp: SparkILoop = _
@@ -92,19 +92,15 @@ object Main extends Logging {
sparkContext
}
- def createSQLContext(): SQLContext = {
- val name = "org.apache.spark.sql.hive.HiveContext"
- val loader = Utils.getContextOrSparkClassLoader
- try {
- sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
- .newInstance(sparkContext).asInstanceOf[SQLContext]
- logInfo("Created sql context (with Hive support)..")
- } catch {
- case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError =>
- sqlContext = new SQLContext(sparkContext)
- logInfo("Created sql context..")
+ def createSparkSession(): SparkSession = {
+ if (SparkSession.hiveClassesArePresent) {
+ sparkSession = SparkSession.withHiveSupport(sparkContext)
+ logInfo("Created Spark session with Hive support")
+ } else {
+ sparkSession = new SparkSession(sparkContext)
+ logInfo("Created Spark session")
}
- sqlContext
+ sparkSession
}
}