aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Holden Karau <holden@pigscanfly.ca> 2015-08-14 11:22:10 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-08-14 11:22:10 -0700
commit: a7317ccdc20d001e5b7f5277b0535923468bfbc6 (patch)
tree: b232f9b383b6ce6c9c14b813f1c66308685eaf2b
parent: 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a (diff)
download: spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.tar.gz
spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.tar.bz2
spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.zip
[SPARK-8744] [ML] Add a public constructor to StringIndexer
It would be helpful to allow users to pass a pre-computed index to create an indexer, rather than always going through StringIndexer to create the model.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7267 from holdenk/SPARK-8744-StringIndexerModel-should-have-public-constructor.
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 2
2 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9f6e7b6b6b..63475780a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -102,10 +102,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
* This is a temporary fix for the case when target labels do not exist during prediction.
*/
@Experimental
-class StringIndexerModel private[ml] (
+class StringIndexerModel (
override val uid: String,
labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
+
private val labelToIndex: OpenHashMap[String, Double] = {
val n = labels.length
val map = new OpenHashMap[String, Double](n)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index fa918ce648..0b4c8ba71e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -30,7 +30,9 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new StringIndexer)
val model = new StringIndexerModel("indexer", Array("a", "b"))
+ val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
ParamsSuite.checkParams(model)
+ ParamsSuite.checkParams(modelWithoutUid)
}
test("StringIndexer") {