aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala11
1 files changed, 10 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 4a3fe5c663..c2b440059b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(regParam, elasticNetParam, standardization, threshold,
+ maxIter, tol, fitIntercept)
+
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
@@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
val (coefficients, intercept, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
@@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(logRegSummary)
+ val m = model.setSummary(logRegSummary)
+ instr.logSuccess(m)
+ m
}
@Since("1.4.0")