diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-11-06 14:51:03 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-11-06 14:51:03 -0800 |
commit | c447c9d54603890db7399fb80adc9fae40b71f64 (patch) | |
tree | 0f8a339ee0b28a00944bea96879600315ab3ef17 /mllib/src/test/java/org | |
parent | 3a652f691b220fada0286f8d0a562c5657973d4d (diff) | |
download | spark-c447c9d54603890db7399fb80adc9fae40b71f64.tar.gz spark-c447c9d54603890db7399fb80adc9fae40b71f64.tar.bz2 spark-c447c9d54603890db7399fb80adc9fae40b71f64.zip |
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:
* class name
* uid
* timestamp
* paramMap
The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases.
~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")
Instance.load("path")
~~~
The param handling is different from the design doc. We didn't save default and user-set params separately, and when we load it back, all parameters are user-set. This does cause issues. But it also causes other issues if we modify the default params.
TODOs:
* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers
cc jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #9454 from mengxr/SPARK-11217.
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 0000000000..c39538014b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.save(outputPath); + try { + instance.save(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + SQLContext sqlContext = new SQLContext(jsc); + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} |