aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-04-29 17:26:46 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-29 17:26:46 -0700
commit114bad606e7a17f980ea6c99e31c8ab0179fec2e (patch)
treef0a1a5f81f6f626412a9a526b72d0b6e5edf570c /mllib/src/test/java/org/apache
parent1fdfdb47b44315ff8ccb0ef92e56d3f2a070f1f1 (diff)
downloadspark-114bad606e7a17f980ea6c99e31c8ab0179fec2e.tar.gz
spark-114bad606e7a17f980ea6c99e31c8ab0179fec2e.tar.bz2
spark-114bad606e7a17f980ea6c99e31c8ab0179fec2e.zip
[SPARK-7176] [ML] Add validation functionality to Param
Main change: Added isValid field to Param. Modified all usages to use isValid when relevant. Added helper methods in ParamValidate. Also overrode Params.validate() in: * CrossValidator + model * Pipeline + model I made a few updates for the elastic net patch: * I changed "tol" to "convergenceTol" * I added some documentation This PR is Scala + Java only. Python will be in a follow-up PR. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #5740 from jkbradley/enforce-validate and squashes the following commits: ad9c6c1 [Joseph K. Bradley] re-generated sharedParams after merging with current master 76415e8 [Joseph K. Bradley] reverted convergenceTol to tol af62f4b [Joseph K. Bradley] Removed changes to SparkBuild, python linalg. Fixed test failures. Renamed ParamValidate to ParamValidators. Removed explicit type from ParamValidators calls where possible. bb2665a [Joseph K. Bradley] merged with elastic net pr ecda302 [Joseph K. Bradley] fix rat tests, plus add a little doc 6895dfc [Joseph K. Bradley] small cleanups 069ac6d [Joseph K. Bradley] many cleanups 928fb84 [Joseph K. Bradley] Maybe done a910ac7 [Joseph K. Bradley] still workin 6d60e2e [Joseph K. Bradley] Still workin b987319 [Joseph K. Bradley] Partly done with adding checks, but blocking on adding checking functionality to Param dbc9fb2 [Joseph K. Bradley] merged with master. enforcing Params.validate
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java66
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java63
2 files changed, 129 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
new file mode 100644
index 0000000000..e7df10dfa6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Test Param and related classes in Java
+ */
+public class JavaParamsSuite {
+
+ private transient JavaSparkContext jsc;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaParamsSuite");
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void testParams() {
+ JavaTestParams testParams = new JavaTestParams();
+ Assert.assertEquals(testParams.getMyIntParam(), 1);
+ testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
+ Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
+ Assert.assertEquals(testParams.getMyStringParam(), "a");
+ }
+
+ @Test
+ public void testParamValidate() {
+ ParamValidators.gt(1.0);
+ ParamValidators.gtEq(1.0);
+ ParamValidators.lt(1.0);
+ ParamValidators.ltEq(1.0);
+ ParamValidators.inRange(0, 1, true, false);
+ ParamValidators.inRange(0, 1);
+ ParamValidators.inArray(Lists.newArrayList(0, 1, 3));
+ ParamValidators.inArray(Lists.newArrayList("a", "b"));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
new file mode 100644
index 0000000000..8abe575610
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+/**
+ * A subclass of Params for testing.
+ */
+public class JavaTestParams extends JavaParams {
+
+ public IntParam myIntParam;
+
+ public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+
+ public JavaTestParams setMyIntParam(int value) {
+ set(myIntParam, value); return this;
+ }
+
+ public DoubleParam myDoubleParam;
+
+ public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+
+ public JavaTestParams setMyDoubleParam(double value) {
+ set(myDoubleParam, value); return this;
+ }
+
+ public Param<String> myStringParam;
+
+ public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+
+ public JavaTestParams setMyStringParam(String value) {
+ set(myStringParam, value); return this;
+ }
+
+ public JavaTestParams() {
+ myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+ myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+ ParamValidators.inRange(0.0, 1.0));
+ List<String> validStrings = Lists.newArrayList("a", "b");
+ myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+ ParamValidators.inArray(validStrings));
+ setDefault(myIntParam, 1);
+ setDefault(myDoubleParam, 0.5);
+ }
+}