aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorjrabary <Jaonary@gmail.com>2015-04-20 09:47:56 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-04-20 09:47:56 -0700
commit1be207078cef48c5935595969bf9f6b1ec1334ca (patch)
tree1dcc1770d108ac6b994527c4d95615695d92f9b1 /mllib/src
parent6fe690d5a8216ba7efde4b52e7a19fb00814341c (diff)
downloadspark-1be207078cef48c5935595969bf9f6b1ec1334ca.tar.gz
spark-1be207078cef48c5935595969bf9f6b1ec1334ca.tar.bz2
spark-1be207078cef48c5935595969bf9f6b1ec1334ca.zip
[SPARK-5924] Add the ability to specify withMean or withStd parameters with StandardScaler
The current implementation calls the default constructor of mllib.feature.StandardScaler without the possibility to specify the withMean or withStd options. Author: jrabary <Jaonary@gmail.com> Closes #4704 from jrabary/master and squashes the following commits: fae8568 [jrabary] style fix 8896b0e [jrabary] Comments fix ef96d73 [jrabary] style fix 8e52607 [jrabary] style fix edd9d48 [jrabary] Fix default param initialization 17e1a76 [jrabary] Fix default param initialization 298f405 [jrabary] Typo fix 45ed914 [jrabary] Add withMean and withStd params to StandardScaler
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala32
1 files changed, 28 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1b102619b3..447851ec03 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -30,7 +30,22 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* Params for [[StandardScaler]] and [[StandardScalerModel]].
*/
-private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
+private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * False by default. Centers the data with mean before scaling.
+ * It will build a dense output, so this does not work on sparse input
+ * and will raise an exception.
+ * @group param
+ */
+ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+
+ /**
+ * True by default. Scales the data to unit standard deviation.
+ * @group param
+ */
+ val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+}
/**
* :: AlphaComponent ::
@@ -40,18 +55,27 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
@AlphaComponent
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+ setDefault(withMean -> false, withStd -> true)
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
-
+
+ /** @group setParam */
+ def setWithMean(value: Boolean): this.type = set(withMean, value)
+
+ /** @group setParam */
+ def setWithStd(value: Boolean): this.type = set(withStd, value)
+
override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = extractParamMap(paramMap)
val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val scaler = new feature.StandardScaler().fit(input)
- val model = new StandardScalerModel(this, map, scaler)
+ val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd))
+ val scalerModel = scaler.fit(input)
+ val model = new StandardScalerModel(this, map, scalerModel)
Params.inheritValues(map, this, model)
model
}