Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
1 file changed, 66 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 1b494ec8b1..24d964fae8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -17,11 +17,14 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
-import org.apache.spark.ml.util.Identifiable
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql._
@@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
*/
@Experimental
class MinMaxScaler(override val uid: String)
- extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+ extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {
def this() = this(Identifiable.randomUID("minMaxScal"))
@@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String)
}
override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScaler extends Readable[MinMaxScaler] {
+
+ @Since("1.6.0")
+ override def read: Reader[MinMaxScaler] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): MinMaxScaler = super.load(path)
}
/**
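With Writable mixed in, the estimator's params can now be persisted and restored through the new ml.util API. A minimal usage sketch (the overwrite()/save() calls follow the Writer API this change builds on; the path is illustrative):

  val scaler = new MinMaxScaler()
    .setInputCol("features")
    .setOutputCol("scaledFeatures")

  // Persists only uid and params; an Estimator carries no fitted state.
  scaler.write.overwrite().save("/tmp/minMaxScaler")

  // Rebuilds an equivalent, unfitted estimator from the saved metadata.
  val restored = MinMaxScaler.load("/tmp/minMaxScaler")
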
@@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] (
override val uid: String,
val originalMin: Vector,
val originalMax: Vector)
- extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+ extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+
+ import MinMaxScalerModel._
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] (
val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new MinMaxScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
+
+ private[MinMaxScalerModel]
+ class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {
+
+ private case class Data(originalMin: Vector, originalMax: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
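+      // Write the param metadata (class, uid, params) under <path>/metadata,
+      // then the min/max vectors as a one-row Parquet file under <path>/data.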
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = new Data(instance.originalMin, instance.originalMax)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {
+
+ private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+
+ override def load(path: String): MinMaxScalerModel = {
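+      // Read the metadata for uid and params, rebuild the model from the
+      // stored vectors, then re-apply the saved params via getAndSetParams.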
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+ .select("originalMin", "originalMax")
+ .head()
+ val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): MinMaxScalerModel = super.load(path)
}
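
Taken together, the model writer stores the param metadata plus a one-row Parquet file holding originalMin and originalMax, and the reader reassembles the model from both. A sketch of the full round trip, assuming a SQLContext named sqlContext is in scope as in the 1.6-era API (data and paths are illustrative):

  import org.apache.spark.mllib.linalg.Vectors

  val df = sqlContext.createDataFrame(Seq(
    (0, Vectors.dense(1.0, 0.1)),
    (1, Vectors.dense(2.0, 1.1)),
    (2, Vectors.dense(3.0, 10.1))
  )).toDF("id", "features")

  val model = new MinMaxScaler()
    .setInputCol("features")
    .setOutputCol("scaledFeatures")
    .fit(df)

  model.write.overwrite().save("/tmp/minMaxScalerModel")

  // load() dispatches to MinMaxScalerModelReader via read.
  val loaded = MinMaxScalerModel.load("/tmp/minMaxScalerModel")
  assert(loaded.originalMin == model.originalMin)
  assert(loaded.originalMax == model.originalMax)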