From 18350a57004eb87cafa9504ff73affab4b818e06 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 11:36:43 -0800 Subject: [SPARK-11618][ML] Minor refactoring of basic ML import/export Refactoring * separated overwrite and param save logic in DefaultParamsWriter * added sparkVersion to DefaultParamsWriter CC: mengxr Author: Joseph K. Bradley Closes #9587 from jkbradley/logreg-io. --- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 57 ++++++++++++---------- 1 file changed, 30 insertions(+), 27 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ea790e0ddd..cbdf913ba8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite { protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { SQLContext.getOrCreate(SparkContext.getOrCreate()) } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext } /** @@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite { +abstract class Writer extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit /** * Overwrites if the output path already exists. @@ -147,28 +172,9 @@ trait Readable[T] { * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { - - /** - * Saves the ML component to the input path. - */ - override def save(path: String): Unit = { - val sc = sqlContext.sparkContext - - val hadoopConf = sc.hadoopConfiguration - val fs = FileSystem.get(hadoopConf) - val p = new Path(path) - if (fs.exists(p)) { - if (shouldOverwrite) { - logInfo(s"Path $path already exists. It will be overwritten.") - // TODO: Revert back to the original content if save is not successful. - fs.delete(p, true) - } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") - } - } +private[ml] class DefaultParamsWriter(instance: Params) extends Writer { + override protected def saveImpl(path: String): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg }.toList val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) val metadataPath = new Path(path, "metadata").toString @@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg */ private[ml] class DefaultParamsReader[T] extends Reader[T] { - /** - * Loads the ML component from the input path. - */ override def load(path: String): T = { implicit val format = DefaultFormats - val sc = sqlContext.sparkContext val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) -- cgit v1.2.3