author    Joseph K. Bradley <joseph@databricks.com>  2015-11-10 11:36:43 -0800
committer Xiangrui Meng <meng@databricks.com>        2015-11-10 11:36:43 -0800
commit    18350a57004eb87cafa9504ff73affab4b818e06 (patch)
tree      5765c5c5e49b356640237cb1b92ed3a2a64ca211 /mllib
parent    f14e95115c0939a77ebcb00209696a87fd651ff9 (diff)
[SPARK-11618][ML] Minor refactoring of basic ML import/export
Refactoring
* separated overwrite and param save logic in DefaultParamsWriter
* added sparkVersion to DefaultParamsWriter

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9587 from jkbradley/logreg-io.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala  57
1 file changed, 30 insertions(+), 27 deletions(-)
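For orientation before reading the diff: this commit converts Writer.save from an abstract method into a template method that owns the overwrite check, with subclasses implementing only saveImpl. A minimal sketch of that pattern, simplified from the hunks below (the Hadoop FileSystem existence check is elided, and the overwrite() body is an assumption based on the documented write.overwrite() call), not the full Spark source:

```scala
import java.io.IOException

// Sketch of the template-method split this commit introduces.
abstract class Writer {
  protected var shouldOverwrite: Boolean = false

  // save() owns the overwrite/existence handling ...
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    // (real code checks the path via Hadoop's FileSystem and deletes it
    //  when shouldOverwrite is set; elided in this sketch)
    saveImpl(path)
  }

  // ... and subclasses implement only the actual persistence.
  protected def saveImpl(path: String): Unit

  // Assumed body, matching the documented write.overwrite().save(path) pattern.
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }
}
```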
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ea790e0ddd..cbdf913ba8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite {
protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
SQLContext.getOrCreate(SparkContext.getOrCreate())
}
+
+ /** Returns the [[SparkContext]] underlying [[sqlContext]] */
+ protected final def sc: SparkContext = sqlContext.sparkContext
}
/**
@@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite {
*/
@Experimental
@Since("1.6.0")
-abstract class Writer extends BaseReadWrite {
+abstract class Writer extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false
@@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite {
*/
@Since("1.6.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
- def save(path: String): Unit
+ def save(path: String): Unit = {
+ val hadoopConf = sc.hadoopConfiguration
+ val fs = FileSystem.get(hadoopConf)
+ val p = new Path(path)
+ if (fs.exists(p)) {
+ if (shouldOverwrite) {
+ logInfo(s"Path $path already exists. It will be overwritten.")
+ // TODO: Revert back to the original content if save is not successful.
+ fs.delete(p, true)
+ } else {
+ throw new IOException(
+ s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
+ }
+ }
+ saveImpl(path)
+ }
+
+ /**
+ * [[save()]] handles overwriting and then calls this method. Subclasses should override this
+ * method to implement the actual saving of the instance.
+ */
+ @Since("1.6.0")
+ protected def saveImpl(path: String): Unit
/**
* Overwrites if the output path already exists.
@@ -147,28 +172,9 @@ trait Readable[T] {
* data (e.g., models with coefficients).
* @param instance object to save
*/
-private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
-
- /**
- * Saves the ML component to the input path.
- */
- override def save(path: String): Unit = {
- val sc = sqlContext.sparkContext
-
- val hadoopConf = sc.hadoopConfiguration
- val fs = FileSystem.get(hadoopConf)
- val p = new Path(path)
- if (fs.exists(p)) {
- if (shouldOverwrite) {
- logInfo(s"Path $path already exists. It will be overwritten.")
- // TODO: Revert back to the original content if save is not successful.
- fs.delete(p, true)
- } else {
- throw new IOException(
- s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
- }
- }
+private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
+ override protected def saveImpl(path: String): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
}.toList
val metadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
+ ("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadataPath = new Path(path, "metadata").toString
@@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
*/
private[ml] class DefaultParamsReader[T] extends Reader[T] {
- /**
- * Loads the ML component from the input path.
- */
override def load(path: String): T = {
implicit val format = DefaultFormats
- val sc = sqlContext.sparkContext
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr)
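As a usage note, the IOException message in the diff names the intended caller-side pattern. A hypothetical example (the `model` value and its `write` member are assumptions for illustration):

```scala
// Hypothetical usage; `model.write` is assumed to return a Writer as above.
// DefaultParamsWriter.saveImpl then writes JSON metadata
// (class, timestamp, sparkVersion, uid, paramMap) under <path>/metadata.
model.write.save("/tmp/my-model")             // throws IOException if the path exists
model.write.overwrite().save("/tmp/my-model") // shouldOverwrite = true: delete, then saveImpl
```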