aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-08-20 14:47:04 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-20 14:47:04 -0700
commit2a3d98aae285aba39786e9809f96de412a130f39 (patch)
treed25b77b2e81598a5be7ad36da1e55decb945b51f /mllib
parent907df2fce00d2cbc9fae371344f05f800e0d2726 (diff)
downloadspark-2a3d98aae285aba39786e9809f96de412a130f39.tar.gz
spark-2a3d98aae285aba39786e9809f96de412a130f39.tar.bz2
spark-2a3d98aae285aba39786e9809f96de412a130f39.zip
[SPARK-10138] [ML] move setters to MultilayerPerceptronClassifier and add Java test suite
Otherwise, setters do not return self type. jkbradley avulanov Author: Xiangrui Meng <meng@databricks.com> Closes #8342 from mengxr/SPARK-10138.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala54
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java74
2 files changed, 101 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index ccca4ecc00..1e5b0bc445 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -42,9 +42,6 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
ParamValidators.arrayLengthGt(1)
)
- /** @group setParam */
- def setLayers(value: Array[Int]): this.type = set(layers, value)
-
/** @group getParam */
final def getLayers: Array[Int] = $(layers)
@@ -61,33 +58,9 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
"it is adjusted to the size of this data. Recommended size is between 10 and 1000",
ParamValidators.gt(0))
- /** @group setParam */
- def setBlockSize(value: Int): this.type = set(blockSize, value)
-
/** @group getParam */
final def getBlockSize: Int = $(blockSize)
- /**
- * Set the maximum number of iterations.
- * Default is 100.
- * @group setParam
- */
- def setMaxIter(value: Int): this.type = set(maxIter, value)
-
- /**
- * Set the convergence tolerance of iterations.
- * Smaller value will lead to higher accuracy with the cost of more iterations.
- * Default is 1E-4.
- * @group setParam
- */
- def setTol(value: Double): this.type = set(tol, value)
-
- /**
- * Set the seed for weights initialization.
- * @group setParam
- */
- def setSeed(value: Long): this.type = set(seed, value)
-
setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
}
@@ -136,6 +109,33 @@ class MultilayerPerceptronClassifier(override val uid: String)
def this() = this(Identifiable.randomUID("mlpc"))
+ /** @group setParam */
+ def setLayers(value: Array[Int]): this.type = set(layers, value)
+
+ /** @group setParam */
+ def setBlockSize(value: Int): this.type = set(blockSize, value)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ * @group setParam
+ */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * Default is 1E-4.
+ * @group setParam
+ */
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /**
+ * Set the seed for weights initialization.
+ * @group setParam
+ */
+ def setSeed(value: Long): this.type = set(seed, value)
+
override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
new file mode 100644
index 0000000000..ec6b4bf3c0
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ sqlContext = null;
+ }
+
+ @Test
+ public void testMLPC() {
+ DataFrame dataFrame = sqlContext.createDataFrame(
+ jsc.parallelize(Arrays.asList(
+ new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
+ LabeledPoint.class);
+ MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
+ .setLayers(new int[] {2, 5, 2})
+ .setBlockSize(1)
+ .setSeed(11L)
+ .setMaxIter(100);
+ MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
+ DataFrame result = model.transform(dataFrame);
+ Row[] predictionAndLabels = result.select("prediction", "label").collect();
+ for (Row r: predictionAndLabels) {
+ Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
+ }
+ }
+}