aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-05-21 22:59:45 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-21 22:59:45 -0700
commit2728c3df6690c2fcd4af3bd1c604c98ef6d509a5 (patch)
treeb69cc8705b56e91ba27195ea508996425ebd5d6f /mllib
parent8f11c6116bf8c7246682cbb2d6f27bf0f1531c6d (diff)
downloadspark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.tar.gz
spark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.tar.bz2
spark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.zip
[SPARK-7578] [ML] [DOC] User guide for spark.ml Normalizer, IDF, StandardScaler
Added user guide sections with code examples. Also added small Java unit tests to test Java example in guide. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #6127 from jkbradley/feature-guide-2 and squashes the following commits: cd47f4b [Joseph K. Bradley] Updated based on code review f16bcec [Joseph K. Bradley] Fixed merge issues and update Python examples print calls for Python 3 0a862f9 [Joseph K. Bradley] Added Normalizer, StandardScaler to ml-features doc, plus small Java unit tests a21c2d6 [Joseph K. Bradley] Updated ml-features.md with IDF
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java17
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java71
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java71
3 files changed, 153 insertions, 6 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 23463ab5fe..da22180563 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -63,17 +63,22 @@ public class JavaHashingTFSuite {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema);
- Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
- DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+ DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("sentence")
+ .setOutputCol("words");
+ DataFrame wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
- .setOutputCol("features")
+ .setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurized = hashingTF.transform(wordsDataFrame);
- for (Row r : featurized.select("features", "words", "label").take(3)) {
+ DataFrame featurizedData = hashingTF.transform(wordsData);
+ IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
+ IDFModel idfModel = idf.fit(featurizedData);
+ DataFrame rescaledData = idfModel.transform(featurizedData);
+ for (Row r : rescaledData.select("features", "label").take(3)) {
Vector features = r.getAs(0);
Assert.assertEquals(features.size(), numFeatures);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
new file mode 100644
index 0000000000..d82f3b7e8c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaNormalizerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void normalizer() {
+ // The tests are to check Java compatibility.
+ List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+ new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
+ new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
+ new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
+ );
+ DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ VectorIndexerSuite.FeatureData.class);
+ Normalizer normalizer = new Normalizer()
+ .setInputCol("features")
+ .setOutputCol("normFeatures");
+
+ // Normalize each Vector using $L^2$ norm.
+ DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+ l2NormData.count();
+
+ // Normalize each Vector using $L^\infty$ norm.
+ DataFrame lInfNormData =
+ normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
+ lInfNormData.count();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
new file mode 100644
index 0000000000..74eb2733f0
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaStandardScalerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void standardScaler() {
+ // The tests are to check Java compatibility.
+ List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+ new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
+ new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
+ new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
+ );
+ DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ VectorIndexerSuite.FeatureData.class);
+ StandardScaler scaler = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("scaledFeatures")
+ .setWithStd(true)
+ .setWithMean(false);
+
+ // Compute summary statistics by fitting the StandardScaler
+ StandardScalerModel scalerModel = scaler.fit(dataFrame);
+
+ // Normalize each feature to have unit standard deviation.
+ DataFrame scaledData = scalerModel.transform(dataFrame);
+ scaledData.count();
+ }
+}