about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2015-05-18 12:02:18 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-05-18 12:02:18 -0700
commit: 9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch)
tree: 2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /mllib
parent: 56ede88485cfca90974425fcb603b257be47229b (diff)
download: spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.gz
spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.bz2
spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.zip
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes:

1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use parent uid and name to identify param.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:

413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 7
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala | 6
3 files changed, 8 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 36d9e17eca..3f7f4f96fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -61,7 +61,7 @@ class RegexTokenizer(override val uid: String)
* Default: 1, to avoid returning empty strings
* @group param
*/
- val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+ val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)",
ParamValidators.gtEq(0))
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 247e08be1b..c33b66d31c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -483,16 +483,15 @@ trait Params extends Identifiable with Serializable {
def copy(extra: ParamMap): Params = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
copyValues(that, extra)
- that
}
/**
* Extracts the embedded default param values and user-supplied values, and then merges them with
* extra values from input into a flat param map, where the latter value is used if there exist
- * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
+ * conflicts, i.e., with ordering: default param values < user-supplied values < extra.
*/
- final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
- defaultParamMap ++ paramMap ++ extraParamMap
+ final def extractParamMap(extra: ParamMap): ParamMap = {
+ defaultParamMap ++ paramMap ++ extra
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
index 1466976800..ddd34a5450 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
@@ -23,15 +23,17 @@ import java.util.UUID
/**
* Trait for an object with an immutable unique ID that identifies itself and its derivatives.
*/
-trait Identifiable {
+private[spark] trait Identifiable {
/**
* An immutable unique ID for the object and its derivatives.
*/
val uid: String
+
+ override def toString: String = uid
}
-object Identifiable {
+private[spark] object Identifiable {
/**
* Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.