diff options
author | Holden Karau <holden@pigscanfly.ca> | 2015-08-14 11:22:10 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-14 11:22:10 -0700 |
commit | a7317ccdc20d001e5b7f5277b0535923468bfbc6 (patch) | |
tree | b232f9b383b6ce6c9c14b813f1c66308685eaf2b | |
parent | 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a (diff) | |
download | spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.tar.gz spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.tar.bz2 spark-a7317ccdc20d001e5b7f5277b0535923468bfbc6.zip |
[SPARK-8744] [ML] Add a public constructor to StringIndexerModel
It would be helpful to allow users to pass a pre-computed array of labels to create a StringIndexerModel directly, rather than always having to fit a StringIndexer to obtain the model.
Author: Holden Karau <holden@pigscanfly.ca>
Closes #7267 from holdenk/SPARK-8744-StringIndexerModel-should-have-public-constructor.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 4 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 2 |
2 files changed, 5 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9f6e7b6b6b..63475780a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,10 +102,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) + private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length val map = new OpenHashMap[String, Double](n) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index fa918ce648..0b4c8ba71e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -30,7 +30,9 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") { |