author      Xiangrui Meng <meng@databricks.com>    2015-11-06 14:51:03 -0800
committer   Joseph K. Bradley <joseph@databricks.com>    2015-11-06 14:51:03 -0800
commit      c447c9d54603890db7399fb80adc9fae40b71f64 (patch)
tree        0f8a339ee0b28a00944bea96879600315ab3ef17 /mllib/src/test/java/org/apache
parent      3a652f691b220fada0286f8d0a562c5657973d4d (diff)
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")
Instance.load("path")
~~~

The param handling differs from the design doc: we don't save default and user-set params separately, so when an instance is loaded back, all parameters are treated as user-set. This does cause issues, but saving the two separately would also cause issues if we later modified the default params.

TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.
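One observable consequence of that design choice, sketched below. Since defaults and user-set values are not stored separately, every param read back from disk reports as user-set. This is a minimal illustrative sketch, not code from the PR: `MyParams` is the test helper exercised in the diff below, and `isSet`/`getOrDefault` are the standard `Params` queries.

~~~scala
// Illustrative sketch: after a save/load round trip, the persisted value
// of intParam comes back as a user-set param, even if it started life as
// a default on the original instance.
val loaded = MyParams.load("path")
loaded.isSet(loaded.intParam)          // true: everything loads as user-set
loaded.getOrDefault(loaded.intParam)   // the persisted value
~~~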
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--    mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java    74
1 file changed, 74 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
new file mode 100644
index 0000000000..c39538014b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.util.Utils;
+
+public class JavaDefaultReadWriteSuite {
+
+ JavaSparkContext jsc = null;
+ File tempDir = null;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
+ tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
+ }
+
+ @After
+ public void tearDown() {
+ if (jsc != null) {
+ jsc.stop();
+ jsc = null;
+ }
+ Utils.deleteRecursively(tempDir);
+ }
+
+ @Test
+ public void testDefaultReadWrite() throws IOException {
+ String uid = "my_params";
+ MyParams instance = new MyParams(uid);
+ instance.set(instance.intParam(), 2);
+ String outputPath = new File(tempDir, uid).getPath();
+ instance.save(outputPath);
+ try {
+ instance.save(outputPath);
+ Assert.fail(
+ "Write without overwrite enabled should fail if the output directory already exists.");
+ } catch (IOException e) {
+ // expected
+ }
+ SQLContext sqlContext = new SQLContext(jsc);
+ instance.write().context(sqlContext).overwrite().save(outputPath);
+ MyParams newInstance = MyParams.load(outputPath);
+ Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
+ Assert.assertEquals("Params should be preserved.",
+ 2, newInstance.getOrDefault(newInstance.intParam()));
+ }
+}
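For reference, the `MyParams` helper used above is defined in the Scala test sources and is not part of this diff. Below is a minimal sketch of such a helper, written against the `DefaultParamsWritable`/`DefaultParamsReadable` convenience traits that later Spark versions expose for exactly this default JSON save/load; the traits introduced by this PR may be named differently, so treat the trait names and the param doc string as assumptions.

~~~scala
package org.apache.spark.ml.util

import org.apache.spark.ml.param.{IntParam, ParamMap, Params}

// Hypothetical sketch of the MyParams test helper: a bare Params
// implementation with one Int param, delegating persistence to the
// default JSON-based writer/reader described in the commit message
// (class name, uid, timestamp, and paramMap).
class MyParams(override val uid: String) extends Params with DefaultParamsWritable {

  // The param exercised by JavaDefaultReadWriteSuite above.
  final val intParam: IntParam = new IntParam(this, "intParam", "an int param")

  override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
}

// The companion supplies read/load, so Java code can call MyParams.load(path).
object MyParams extends DefaultParamsReadable[MyParams]
~~~

With a helper shaped like this, the Java calls in the test (`instance.set(...)`, `instance.save(...)`, `instance.write().context(...).overwrite().save(...)`, `MyParams.load(...)`) all resolve through the `Params` API and the companion's generated reader.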