aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala12
1 files changed, 10 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index af68e7b9d5..2f78dd30b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -227,6 +227,12 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
val numFeatures = featuresStd.size
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(labelCol, featuresCol, censorCol, predictionCol, quantilesCol,
+ fitIntercept, maxIter, tol, aggregationDepth)
+ instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length)
+ instr.logNumFeatures(numFeatures)
+
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) {
logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " +
@@ -276,8 +282,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
- val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
- copyValues(model.setParent(this))
+ val model = copyValues(new AFTSurvivalRegressionModel(uid, coefficients,
+ intercept, scale).setParent(this))
+ instr.logSuccess(model)
+ model
}
@Since("1.6.0")