diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-05-21 22:59:45 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-21 22:59:45 -0700 |
commit | 2728c3df6690c2fcd4af3bd1c604c98ef6d509a5 (patch) | |
tree | b69cc8705b56e91ba27195ea508996425ebd5d6f /mllib/src/test/java/org/apache | |
parent | 8f11c6116bf8c7246682cbb2d6f27bf0f1531c6d (diff) | |
download | spark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.tar.gz spark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.tar.bz2 spark-2728c3df6690c2fcd4af3bd1c604c98ef6d509a5.zip |
[SPARK-7578] [ML] [DOC] User guide for spark.ml Normalizer, IDF, StandardScaler
Added user guide sections with code examples.
Also added small Java unit tests to test Java example in guide.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6127 from jkbradley/feature-guide-2 and squashes the following commits:
cd47f4b [Joseph K. Bradley] Updated based on code review
f16bcec [Joseph K. Bradley] Fixed merge issues and update Python examples print calls for Python 3
0a862f9 [Joseph K. Bradley] Added Normalizer, StandardScaler to ml-features doc, plus small Java unit tests
a21c2d6 [Joseph K. Bradley] Updated ml-features.md with IDF
Diffstat (limited to 'mllib/src/test/java/org/apache')
3 files changed, 153 insertions, 6 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 23463ab5fe..da22180563 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -63,17 +63,22 @@ public class JavaHashingTFSuite { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema); - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer() + .setInputCol("sentence") + .setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") - .setOutputCol("features") + .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurized = hashingTF.transform(wordsDataFrame); - for (Row r : featurized.select("features", "words", "label").take(3)) { + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java new file mode 100644 index 0000000000..d82f3b7e8c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaNormalizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void normalizer() { + // The tests are to check Java compatibility. + List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures"); + + // Normalize each Vector using $L^2$ norm. + DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + l2NormData.count(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java new file mode 100644 index 0000000000..74eb2733f0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaStandardScalerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void standardScaler() { + // The tests are to check Java compatibility. + List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.count(); + } +} |