aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorRam Sriharsha <rsriharsha@hw11853.local>2015-05-12 13:35:12 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-12 13:35:12 -0700
commit595a67589a42f8025d3e5fd4da413b1faa2e14bf (patch)
tree54073754a09b6ff793ba03fd4711dfcb16c7ad42 /mllib/src/test/java/org/apache
parent5438f49ccf374fed16bc2b7fc1556e4c0095b14c (diff)
downloadspark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.tar.gz
spark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.tar.bz2
spark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.zip
[SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All
initial cut of one against all. test code is a scaffolding , not fully implemented. This WIP is to gather early feedback. Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #5830 from harsha2010/reduction and squashes the following commits: 5f4b495 [Ram Sriharsha] Fix Test 386e98b [Ram Sriharsha] Style fix 49b4a17 [Ram Sriharsha] Simplify the test 02279cc [Ram Sriharsha] Output Label Metadata in Prediction Col bc78032 [Ram Sriharsha] Code Review Updates 8ce4845 [Ram Sriharsha] Merge with Master 2a807be [Ram Sriharsha] Merge branch 'master' into reduction e21bfcc [Ram Sriharsha] Style Fix 5614f23 [Ram Sriharsha] Style Fix c75583a [Ram Sriharsha] Cleanup 7a5f136 [Ram Sriharsha] Fix TODOs 804826b [Ram Sriharsha] Merge with Master 1448a5f [Ram Sriharsha] Style Fix 6e47807 [Ram Sriharsha] Style Fix d63e46b [Ram Sriharsha] Incorporate Code Review Feedback ced68b5 [Ram Sriharsha] Refactor OneVsAll to implement Predictor 78fa82a [Ram Sriharsha] extra line 0dfa1fb [Ram Sriharsha] Fix inexhaustive match cases that may arise from UnresolvedAttribute a59a4f4 [Ram Sriharsha] @Experimental 4167234 [Ram Sriharsha] Merge branch 'master' into reduction 868a4fd [Ram Sriharsha] @Experimental 041d905 [Ram Sriharsha] Code Review Fixes df188d8 [Ram Sriharsha] Style fix 612ec48 [Ram Sriharsha] Style Fix 6ef43d3 [Ram Sriharsha] Prefer Unresolved Attribute to Option: Java APIs are cleaner 6bf6bff [Ram Sriharsha] Update OneHotEncoder to new API e29cb89 [Ram Sriharsha] Merge branch 'master' into reduction 1c7fa44 [Ram Sriharsha] Fix Tests ca83672 [Ram Sriharsha] Incorporate Code Review Feedback + Rename to OneVsRestClassifier 221beeed [Ram Sriharsha] Upgrade to use Copy method for cloning Base Classifiers 26f1ddb [Ram Sriharsha] Merge with SPARK-5956 API changes 9738744 [Ram Sriharsha] Merge branch 'master' into reduction 1a3e375 [Ram Sriharsha] More efficient Implementation: Use withColumn to generate label column dynamically 32e0189 [Ram Sriharsha] Restrict reduction to Margin Based Classifiers ff272da [Ram Sriharsha] Style fix 28771f5 [Ram Sriharsha] Add Tests for Multiclass to Binary Reduction b60f874 [Ram Sriharsha] Fix Style issues in Test 3191cdf [Ram Sriharsha] Remove this test, accidental commit 23f056c [Ram Sriharsha] Fix Headers for test 1b5e929 [Ram Sriharsha] Fix Style issues and add Header 8752863 [Ram Sriharsha] [SPARK-7015][MLLib][WIP] Multiclass to Binary Reduction: One Against All
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java85
1 files changed, 85 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
new file mode 100644
index 0000000000..40a90ae9de
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
+ jsql = new SQLContext(jsc);
+ int nPoints = 3;
+
+ /**
+ * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
+ * As a result, we are actually drawing samples from probability distribution of built model.
+ */
+ double[] weights = {
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+ double[] xMean = {5.843, 3.057, 3.758, 1.199};
+ double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+ List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42));
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void oneVsRestDefaultParams() {
+ OneVsRest ova = new OneVsRest();
+ ova.setClassifier(new LogisticRegression());
+ Assert.assertEquals(ova.getLabelCol() , "label");
+ Assert.assertEquals(ova.getPredictionCol() , "prediction");
+ OneVsRestModel ovaModel = ova.fit(dataset);
+ DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ predictions.collectAsList();
+ Assert.assertEquals(ovaModel.getLabelCol(), "label");
+ Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
+ }
+}