From 3434572b141075f00698d94e6ee80febd3093c3b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 3 Nov 2015 08:31:16 -0800 Subject: [MINOR][ML] Fix naming conventions of AFTSurvivalRegression coefficients Rename ```regressionCoefficients``` back to ```coefficients```, and rename the combined optimization vector (the "weights": log(sigma), intercept, and regression coefficients) to ```parameters```. See discussion [here](https://github.com/apache/spark/pull/9311/files#diff-e277fd0bc21f825d3196b4551c01fe5fR230). mengxr vectorijk dbtsai Author: Yanbo Liang Closes #9431 from yanboliang/aft-coefficients. --- .../ml/regression/AFTSurvivalRegression.scala | 38 +++++++++++----------- .../ml/regression/AFTSurvivalRegressionSuite.scala | 12 +++---- 2 files changed, 25 insertions(+), 25 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4dbbc7d399..b7d095872f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -200,17 +200,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size /* - The coefficients vector has three parts: + The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter the second element: Double, intercept of the beta parameter the third to the end elements: Doubles, regression coefficients vector of the beta parameter */ - val initialCoefficients = Vectors.zeros(numFeatures + 2) + val initialParameters = Vectors.zeros(numFeatures + 2) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficients.toBreeze.toDenseVector) + initialParameters.toBreeze.toDenseVector) - val coefficients = { + val parameters = { val arrayBuilder = mutable.ArrayBuilder.make[Double] var state: 
optimizer.State = null while (states.hasNext) { @@ -227,10 +227,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val regressionCoefficients = Vectors.dense(coefficients.slice(2, coefficients.length)) - val intercept = coefficients(1) - val scale = math.exp(coefficients(0)) - val model = new AFTSurvivalRegressionModel(uid, regressionCoefficients, intercept, scale) + val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val intercept = parameters(1) + val scale = math.exp(parameters(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) copyValues(model.setParent(this)) } @@ -251,7 +251,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("1.6.0") class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override val uid: String, - @Since("1.6.0") val regressionCoefficients: Vector, + @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { @@ -275,7 +275,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { // scale parameter for the Weibull distribution of lifetime - val lambda = math.exp(BLAS.dot(regressionCoefficients, features) + intercept) + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) // shape parameter for the Weibull distribution of lifetime val k = 1 / scale val quantiles = $(quantileProbabilities).map { @@ -286,7 +286,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predict(features: Vector): Double = { - math.exp(BLAS.dot(regressionCoefficients, features) + intercept) + math.exp(BLAS.dot(coefficients, features) + intercept) } @Since("1.6.0") @@ -309,7 +309,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override 
def copy(extra: ParamMap): AFTSurvivalRegressionModel = { - copyValues(new AFTSurvivalRegressionModel(uid, regressionCoefficients, intercept, scale), extra) + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } } @@ -369,17 +369,17 @@ class AFTSurvivalRegressionModel private[ml] ( * \frac{\partial (-\iota)}{\partial (\log\sigma)}= * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] * }}} - * @param coefficients including three part: The log of scale parameter, the intercept and + * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. */ -private class AFTAggregator(coefficients: BDV[Double], fitIntercept: Boolean) +private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { // beta is the intercept and regression coefficients to the covariates - private val beta = coefficients.slice(1, coefficients.length) + private val beta = parameters.slice(1, parameters.length) // sigma is the scale parameter of the AFT model - private val sigma = math.exp(coefficients(0)) + private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 @@ -449,15 +449,15 @@ private class AFTAggregator(coefficients: BDV[Double], fitIntercept: Boolean) /** * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. - * It returns the loss and gradient at a particular point (coefficients). + * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. 
*/ private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) extends DiffFunction[BDV[Double]] { - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index c0f791bce1..359f310271 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -141,12 +141,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(-0.039) + val coefficientsR = Vectors.dense(-0.039) val interceptR = 1.759 val scaleR = 1.41 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* @@ -212,12 +212,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(-0.0844, 0.0677) + val coefficientsR = Vectors.dense(-0.0844, 0.0677) val interceptR = 1.9206 val scaleR = 0.977 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 
1E-3) /* @@ -282,12 +282,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 6 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(0.896, -0.709) + val coefficientsR = Vectors.dense(0.896, -0.709) val interceptR = 0.0 val scaleR = 1.52 assert(model.intercept === interceptR) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* -- cgit v1.2.3