path: root/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala  41
1 file changed, 30 insertions(+), 11 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 90b8d7df7b..1582a73ea0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -40,28 +40,41 @@ import org.apache.spark.util.Utils
  * Trait for [[MLWriter]] and [[MLReader]].
  */
 private[util] sealed trait BaseReadWrite {
-  private var optionSQLContext: Option[SQLContext] = None
+  private var optionSparkSession: Option[SparkSession] = None
 
   /**
-   * Sets the SQL context to use for saving/loading.
+   * Sets the Spark SQLContext to use for saving/loading.
    */
   @Since("1.6.0")
+  @deprecated("Use session instead", "2.0.0")
   def context(sqlContext: SQLContext): this.type = {
-    optionSQLContext = Option(sqlContext)
+    optionSparkSession = Option(sqlContext.sparkSession)
     this
   }
 
   /**
-   * Returns the user-specified SQL context or the default.
+   * Sets the Spark Session to use for saving/loading.
    */
-  protected final def sqlContext: SQLContext = {
-    if (optionSQLContext.isEmpty) {
-      optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate()))
+  @Since("2.0.0")
+  def session(sparkSession: SparkSession): this.type = {
+    optionSparkSession = Option(sparkSession)
+    this
+  }
+
+  /**
+   * Returns the user-specified Spark Session or the default.
+   */
+  protected final def sparkSession: SparkSession = {
+    if (optionSparkSession.isEmpty) {
+      optionSparkSession = Some(SparkSession.builder().getOrCreate())
     }
-    optionSQLContext.get
+    optionSparkSession.get
   }
 
-  protected final def sparkSession: SparkSession = sqlContext.sparkSession
+  /**
+   * Returns the user-specified SQL context or the default.
+   */
+  protected final def sqlContext: SQLContext = sparkSession.sqlContext
 
   /** Returns the underlying [[SparkContext]]. */
   protected final def sc: SparkContext = sparkSession.sparkContext
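
The new sparkSession accessor makes the resolution order explicit: a session supplied via session(...) wins, and otherwise SparkSession.builder().getOrCreate() lazily picks up (or creates) the application's default session. A minimal sketch, not part of this patch (the object name and appName/master settings are arbitrary), showing that the fallback resolves to the caller's existing session rather than building a second one:

import org.apache.spark.sql.SparkSession

object DefaultSessionSketch {
  def main(args: Array[String]): Unit = {
    // The application's own session; appName/master here are placeholders.
    val spark = SparkSession.builder()
      .appName("default-session-sketch")
      .master("local[*]")
      .getOrCreate()

    // This mirrors the fallback above: with a session already active,
    // getOrCreate() returns that same instance instead of a new one.
    val resolved = SparkSession.builder().getOrCreate()
    assert(resolved eq spark)

    spark.stop()
  }
}
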
@@ -118,7 +131,10 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   }
 
   // override for Java compatibility
-  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
 }
 
 /**
@@ -180,7 +196,10 @@ abstract class MLReader[T] extends BaseReadWrite {
   def load(path: String): T
 
   // override for Java compatibility
-  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
 }
 
 /**
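
Taken together, the hunks above change only how a session reaches the writer or reader; caller code migrates by swapping the deprecated context(...) for session(...). A hedged usage sketch, assuming a saved PipelineModel (the model and paths here are hypothetical):

import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession

object SessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("session-migration-sketch")
      .master("local[*]")
      .getOrCreate()

    // Hypothetical path to a previously saved pipeline model.
    val model = PipelineModel.load("/tmp/pipeline-model")

    // Before (deprecated in 2.0.0): model.write.context(spark.sqlContext).save(path)
    // After: hand over the SparkSession directly; both overloads feed the
    // same optionSparkSession slot, so behavior is unchanged.
    model.write.session(spark).overwrite().save("/tmp/pipeline-model-v2")

    // MLReader gains the same method, so loading mirrors saving.
    val reloaded = PipelineModel.read.session(spark).load("/tmp/pipeline-model-v2")
    println(s"reloaded ${reloaded.stages.length} stages")

    spark.stop()
  }
}

Because session(...) and the deprecated context(...) both return this.type, the re-declared overrides in MLWriter and MLReader exist so that Java callers, which do not see Scala's this.type refinement, still get back the concrete writer or reader and can keep chaining calls.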