about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala 11
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala 11
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 14
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 3
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala 2
-rw-r--r-- python/pyspark/ml/classification.py 15
-rw-r--r-- python/pyspark/ml/clustering.py 162
-rw-r--r-- python/pyspark/ml/regression.py 16
-rwxr-xr-x python/pyspark/ml/tests.py 32
16 files changed, 256 insertions, 47 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f58efd36a1..d07b4adebb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -648,7 +648,7 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(logRegSummary)
+ model.setSummary(Some(logRegSummary))
} else {
model
}
@@ -790,9 +790,9 @@ class LogisticRegressionModel private[spark] (
}
}
- private[classification] def setSummary(
- summary: LogisticRegressionTrainingSummary): this.type = {
- this.trainingSummary = Some(summary)
+ private[classification]
+ def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -887,8 +887,7 @@ class LogisticRegressionModel private[spark] (
override def copy(extra: ParamMap): LogisticRegressionModel = {
val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial), extra)
- if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
- newModel.setParent(parent)
+ newModel.setSummary(trainingSummary).setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f8a606d60b..e6ca3aedff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
- if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
- copied.setParent(this.parent)
+ copied.setSummary(trainingSummary).setParent(this.parent)
}
@Since("2.0.0")
@@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] (
private var trainingSummary: Option[BisectingKMeansSummary] = None
- private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
- this.trainingSummary = Some(summary)
+ private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") (
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
- model.setSummary(summary)
+ model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index c6035cc4c9..92d0b7d085 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
- if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
- copied.setParent(this.parent)
+ copied.setSummary(trainingSummary).setParent(this.parent)
}
@Since("2.0.0")
@@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] (
private var trainingSummary: Option[GaussianMixtureSummary] = None
- private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
- this.trainingSummary = Some(summary)
+ private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") (
.setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
- model.setSummary(summary)
+ model.setSummary(Some(summary))
instr.logNumFeatures(model.gaussians.head.mean.size)
instr.logSuccess(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 26505b4cc1..152bd13b7a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -110,8 +110,7 @@ class KMeansModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
- if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
- copied.setParent(this.parent)
+ copied.setSummary(trainingSummary).setParent(this.parent)
}
/** @group setParam */
@@ -165,8 +164,8 @@ class KMeansModel private[ml] (
private var trainingSummary: Option[KMeansSummary] = None
- private[clustering] def setSummary(summary: KMeansSummary): this.type = {
- this.trainingSummary = Some(summary)
+ private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") (
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
- model.setSummary(summary)
+ model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 736fd3b9e0..3f9de1fe74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
wlsModel.diagInvAtWA.toArray, 1, getSolver)
- return model.setSummary(trainingSummary)
+ return model.setSummary(Some(trainingSummary))
}
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
@@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
- model.setSummary(trainingSummary)
+ model.setSummary(Some(trainingSummary))
}
@Since("2.0.0")
@@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.nonEmpty
private[regression]
- def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
- this.trainingSummary = Some(summary)
+ def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] (
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
extra)
- if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
- copied.setParent(parent)
+ copied.setSummary(trainingSummary).setParent(parent)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index da7ce6b46f..8ea5e1e6c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -225,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model.diagInvAtWA.toArray,
model.objectiveHistory)
- return lrModel.setSummary(trainingSummary)
+ return lrModel.setSummary(Some(trainingSummary))
}
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -278,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
Array(0D))
- return model.setSummary(trainingSummary)
+ return model.setSummary(Some(trainingSummary))
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
@@ -400,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
objectiveHistory)
- model.setSummary(trainingSummary)
+ model.setSummary(Some(trainingSummary))
}
@Since("1.4.0")
@@ -446,8 +446,9 @@ class LinearRegressionModel private[ml] (
throw new SparkException("No training summary available for this LinearRegressionModel")
}
- private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
- this.trainingSummary = Some(summary)
+ private[regression]
+ def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
+ this.trainingSummary = summary
this
}
@@ -490,8 +491,7 @@ class LinearRegressionModel private[ml] (
@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegressionModel = {
val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra)
- if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
- newModel.setParent(parent)
+ newModel.setSummary(trainingSummary).setParent(parent)
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 2877285eb4..e360542eae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -147,6 +147,8 @@ class LogisticRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
+ model.setSummary(None)
+ assert(!model.hasSummary)
}
test("empty probabilityCol") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 49797d938d..fc491cd616 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -109,6 +109,9 @@ class BisectingKMeansSuite
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))
+
+ model.setSummary(None)
+ assert(!model.hasSummary)
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 7165b63ed3..07299123f8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))
+
+ model.setSummary(None)
+ assert(!model.hasSummary)
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 73972557d2..c1b7242e11 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))
+
+ model.setSummary(None)
+ assert(!model.hasSummary)
}
test("KMeansModel transform with non-default feature and prediction cols") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 6a4ac1735b..9b0fa67630 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
+ model.setSummary(None)
+ assert(!model.hasSummary)
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index df97d0b2ae..0be82742a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -146,6 +146,8 @@ class LinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
+ model.setSummary(None)
+ assert(!model.hasSummary)
model.transform(datasetWithDenseFeature)
.select("label", "prediction")
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 56c8c62259..83e1e89347 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -309,13 +309,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
@since("2.0.0")
def summary(self):
"""
- Gets summary (e.g. residuals, mse, r-squared ) of model on
- training set. An exception is thrown if
- `trainingSummary is None`.
+ Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
- java_blrt_summary = self._call_java("summary")
- # Note: Once multiclass is added, update this to return correct summary
- return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+ if self.hasSummary:
+ java_blrt_summary = self._call_java("summary")
+ # Note: Once multiclass is added, update this to return correct summary
+ return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
@property
@since("2.0.0")
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7632f05c3b..e58ec1e7ac 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -17,16 +17,74 @@
from pyspark import since, keyword_only
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
-__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
+__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
'KMeans', 'KMeansModel',
- 'GaussianMixture', 'GaussianMixtureModel',
+ 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary',
'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']
+class ClusteringSummary(JavaWrapper):
+ """
+ .. note:: Experimental
+
+ Clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+
+ @property
+ @since("2.1.0")
+ def predictionCol(self):
+ """
+ Name for column of predicted clusters in `predictions`.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.1.0")
+ def predictions(self):
+ """
+ DataFrame produced by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.1.0")
+ def featuresCol(self):
+ """
+ Name for column of features in `predictions`.
+ """
+ return self._call_java("featuresCol")
+
+ @property
+ @since("2.1.0")
+ def k(self):
+ """
+ The number of clusters the model was trained with.
+ """
+ return self._call_java("k")
+
+ @property
+ @since("2.1.0")
+ def cluster(self):
+ """
+ DataFrame of predicted cluster centers for each training data point.
+ """
+ return self._call_java("cluster")
+
+ @property
+ @since("2.1.0")
+ def clusterSizes(self):
+ """
+ Size of (number of data points in) each cluster.
+ """
+ return self._call_java("clusterSizes")
+
+
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -56,6 +114,28 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("gaussiansDF")
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+ training set. An exception is thrown if no summary exists.
+ """
+ if self.hasSummary:
+ return GaussianMixtureSummary(self._call_java("summary"))
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
@@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> gm = GaussianMixture(k=3, tol=0.0001,
... maxIter=10, seed=10)
>>> model = gm.fit(df)
+ >>> model.hasSummary
+ True
+ >>> summary = model.summary
+ >>> summary.k
+ 3
+ >>> summary.clusterSizes
+ [2, 2, 2]
>>> weights = model.weights
>>> len(weights)
3
@@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> model_path = temp_path + "/gmm_model"
>>> model.save(model_path)
>>> model2 = GaussianMixtureModel.load(model_path)
+ >>> model2.hasSummary
+ False
>>> model2.weights == model.weights
True
>>> model2.gaussiansDF.show()
@@ -181,6 +270,32 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
return self.getOrDefault(self.k)
+class GaussianMixtureSummary(ClusteringSummary):
+ """
+ .. note:: Experimental
+
+ Gaussian mixture clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+
+ @property
+ @since("2.1.0")
+ def probabilityCol(self):
+ """
+ Name for column of predicted probability of each cluster in `predictions`.
+ """
+ return self._call_java("probabilityCol")
+
+ @property
+ @since("2.1.0")
+ def probability(self):
+ """
+ DataFrame of probabilities of each cluster for each training data point.
+ """
+ return self._call_java("probability")
+
+
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -346,6 +461,27 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("computeCost", dataset)
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+ training set. An exception is thrown if no summary exists.
+ """
+ if self.hasSummary:
+ return BisectingKMeansSummary(self._call_java("summary"))
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
@inherit_doc
class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed,
@@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
2
>>> model.computeCost(df)
2.000...
+ >>> model.hasSummary
+ True
+ >>> summary = model.summary
+ >>> summary.k
+ 2
+ >>> summary.clusterSizes
+ [2, 2]
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction
@@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
>>> model_path = temp_path + "/bkm_model"
>>> model.save(model_path)
>>> model2 = BisectingKMeansModel.load(model_path)
+ >>> model2.hasSummary
+ False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
@@ -460,6 +605,17 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
return BisectingKMeansModel(java_model)
+class BisectingKMeansSummary(ClusteringSummary):
+ """
+ .. note:: Experimental
+
+ Bisecting KMeans clustering results for a given model.
+
+ .. versionadded:: 2.1.0
+ """
+ pass
+
+
@inherit_doc
class LDAModel(JavaModel):
"""
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 0bc319ca4d..385391ba53 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -160,8 +160,12 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, Java
training set. An exception is thrown if
`trainingSummary is None`.
"""
- java_lrt_summary = self._call_java("summary")
- return LinearRegressionTrainingSummary(java_lrt_summary)
+ if self.hasSummary:
+ java_lrt_summary = self._call_java("summary")
+ return LinearRegressionTrainingSummary(java_lrt_summary)
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
@property
@since("2.0.0")
@@ -1459,8 +1463,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri
training set. An exception is thrown if
`trainingSummary is None`.
"""
- java_glrt_summary = self._call_java("summary")
- return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
+ if self.hasSummary:
+ java_glrt_summary = self._call_java("summary")
+ return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
@property
@since("2.0.0")
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9d46cc3b4a..c0f0d40735 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1097,6 +1097,38 @@ class TrainingSummaryTest(SparkSessionTestCase):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+ def test_gaussian_mixture_summary(self):
+ data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+ (Vectors.sparse(1, [], []),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ gmm = GaussianMixture(k=2)
+ model = gmm.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertTrue(isinstance(s.probability, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+
+ def test_bisecting_kmeans_summary(self):
+ data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+ (Vectors.sparse(1, [], []),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ bkm = BisectingKMeans(k=2)
+ model = bkm.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+
class OneVsRestTests(SparkSessionTestCase):