aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-02-22 12:59:50 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-02-22 12:59:50 +0200
commit40e6d40fe79ce45d511e049133d2f30a2963740b (patch)
treeaac6440fc376cb836b96e591c9926910201ed996 /mllib
parente298ac91e3f6177c6da83e2d8ee994d9037466da (diff)
downloadspark-40e6d40fe79ce45d511e049133d2f30a2963740b.tar.gz
spark-40e6d40fe79ce45d511e049133d2f30a2963740b.tar.bz2
spark-40e6d40fe79ce45d511e049133d2f30a2963740b.zip
[SPARK-13334][ML] ML KMeansModel / BisectingKMeansModel / QuantileDiscretizer should set parent
ML ```KMeansModel / BisectingKMeansModel / QuantileDiscretizer``` should set parent. cc mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #11214 from yanboliang/spark-13334.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala4
6 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 0b47cbbac8..45d293bc69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -185,7 +185,7 @@ class BisectingKMeans @Since("2.0.0") (
.setSeed($(seed))
val parentModel = bkm.run(rdd)
val model = new BisectingKMeansModel(uid, parentModel)
- copyValues(model)
+ copyValues(model.setParent(this))
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index dc6d5d9280..b2292e20e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -250,7 +250,7 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = new KMeansModel(uid, parentModel)
- copyValues(model)
+ copyValues(model.setParent(this))
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 2a294d3881..1f4cca1233 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -95,7 +95,7 @@ final class QuantileDiscretizer(override val uid: String)
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
val bucketizer = new Bucketizer(uid).setSplits(splits)
- copyValues(bucketizer)
+ copyValues(bucketizer.setParent(this))
}
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b26571eb9f..fc4a4add5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -81,5 +81,6 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
+ assert(model.hasParent)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2724e51f31..e5357ba8e2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -97,6 +97,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
+ assert(model.hasParent)
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 4fde42972f..6a2c601bbe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -94,7 +94,9 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
.setNumBuckets(numBucket).setSeed(1)
- val result = discretizer.fit(df).transform(df)
+ val model = discretizer.fit(df)
+ assert(model.hasParent)
+ val result = model.transform(df)
val transformedFeatures = result.select("result").collect()
.map { case Row(transformedFeature: Double) => transformedFeature }