author    Yanbo Liang <ybliang8@gmail.com>  2015-11-20 09:55:53 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-11-20 09:55:53 -0800
commit    9ace2e5c8d7fbd360a93bc5fc4eace64a697b44f (patch)
tree      6107dc359814aadcef8640d2e7c4a0926e5b0a0c /mllib
parent    a66142decee48bf5689fb7f4f33646d7bb1ac08d (diff)
[SPARK-11852][ML] StandardScaler minor refactor
```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9839 from yanboliang/standardScaler-refactor.
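For context, a minimal usage sketch of the API after this refactor (the input DataFrame ```df``` and the column names are placeholders, and the ```setWithMean```/```setWithStd``` setters already exist on ```StandardScaler``` and are not part of this diff): the two flags now live in the shared ```StandardScalerParams``` trait, so the fitted model reports them through ```getWithMean```/```getWithStd``` and exposes the fitted statistics directly as ```std``` and ```mean```.

```scala
import org.apache.spark.ml.feature.StandardScaler

// Hypothetical input: df is a DataFrame with a Vector column "features".
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithMean(false)  // default: false (centering densifies sparse input)
  .setWithStd(true)    // default: true

val model = scaler.fit(df)

model.getWithMean      // false, copied from the estimator via copyValues
model.getWithStd       // true
model.std              // Vector of per-feature standard deviations
model.mean             // Vector of per-feature means
model.transform(df)    // appends the "scaledFeatures" column
```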
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala      | 60
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala | 11
2 files changed, 32 insertions(+), 39 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 6d545219eb..d76a9c6275 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
/**
- * Centers the data with mean before scaling.
+ * Whether to center the data with mean before scaling.
* It will build a dense output, so this does not work on sparse input
* and will raise an exception.
* Default: false
* @group param
*/
- val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+ val withMean: BooleanParam = new BooleanParam(this, "withMean",
+ "Whether to center data with mean")
+
+ /** @group getParam */
+ def getWithMean: Boolean = $(withMean)
/**
- * Scales the data to unit standard deviation.
+ * Whether to scale the data to unit standard deviation.
* Default: true
* @group param
*/
- val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+ val withStd: BooleanParam = new BooleanParam(this, "withStd",
+ "Whether to scale the data to unit standard deviation")
+
+ /** @group getParam */
+ def getWithStd: Boolean = $(withStd)
+
+ setDefault(withMean -> false, withStd -> true)
}
/**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
def this() = this(Identifiable.randomUID("stdScal"))
- setDefault(withMean -> false, withStd -> true)
-
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
- copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+ copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
/**
* :: Experimental ::
* Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
*/
@Experimental
class StandardScalerModel private[ml] (
override val uid: String,
- scaler: feature.StandardScalerModel)
+ val std: Vector,
+ val mean: Vector)
extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
import StandardScalerModel._
- /** Standard deviation of the StandardScalerModel */
- val std: Vector = scaler.std
-
- /** Mean of the StandardScalerModel */
- val mean: Vector = scaler.mean
-
- /** Whether to scale to unit standard deviation. */
- @Since("1.6.0")
- def getWithStd: Boolean = scaler.withStd
-
- /** Whether to center data with mean. */
- @Since("1.6.0")
- def getWithMean: Boolean = scaler.withMean
-
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }
dataset.withColumn($(outputCol), scale(col($(inputCol))))
}
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
}
override def copy(extra: ParamMap): StandardScalerModel = {
- val copied = new StandardScalerModel(uid, scaler)
+ val copied = new StandardScalerModel(uid, std, mean)
copyValues(copied, extra).setParent(parent)
}
@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
private[StandardScalerModel]
class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
- private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+ private case class Data(std: Vector, mean: Vector)
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+ val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
- sqlContext.read.parquet(dataPath)
- .select("std", "mean", "withStd", "withMean")
- .head()
- // This is very likely to change in the future because withStd and withMean should be params.
- val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
- val model = new StandardScalerModel(metadata.uid, oldModel)
+ val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+ .select("std", "mean")
+ .head()
+ val model = new StandardScalerModel(metadata.uid, std, mean)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 49a4b2efe0..1eae125a52 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
test("params") {
ParamsSuite.checkParams(new StandardScaler)
- val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
- ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+ ParamsSuite.checkParams(new StandardScalerModel("empty",
+ Vectors.dense(1.0), Vectors.dense(2.0)))
}
test("Standardization with default parameter") {
@@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("StandardScalerModel read/write") {
- val oldModel = new feature.StandardScalerModel(
- Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
- val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+ val instance = new StandardScalerModel("myStandardScalerModel",
+ Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.std === instance.std)
assert(newInstance.mean === instance.mean)
- assert(newInstance.getWithStd === instance.getWithStd)
- assert(newInstance.getWithMean === instance.getWithMean)
}
}
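Following the read/write test above, a hedged sketch of the persistence round-trip after this change (the save path, column name, and ```df``` are placeholders): only ```std``` and ```mean``` go into the Parquet data file, while ```withStd``` and ```withMean``` are saved with the params metadata and restored by ```DefaultParamsReader.getAndSetParams```.

```scala
import org.apache.spark.ml.feature.{StandardScaler, StandardScalerModel}

// Fit a model, save it, and load it back; df is a hypothetical DataFrame
// with a Vector column "features".
val model = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .fit(df)

model.write.overwrite().save("/tmp/std-scaler-model")
val restored = StandardScalerModel.load("/tmp/std-scaler-model")

// std and mean come from the data file; withStd/withMean come from metadata.
assert(restored.std == model.std)
assert(restored.mean == model.mean)
assert(restored.getWithStd == model.getWithStd)
assert(restored.getWithMean == model.getWithMean)
```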