about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: vectorijk <jiangkai@gmail.com>, 2015-10-06 12:43:28 -0700
committer: Xiangrui Meng <meng@databricks.com>, 2015-10-06 12:43:28 -0700
commit: 5952bdb7df20d007d59f82261095faca3822c6f6 (patch)
tree: fe4b4f0df4e47173d70cff41144acf4b436b93de
parent: e9783601599758df87418bf61a7b4636f06714fa (diff)
download: spark-5952bdb7df20d007d59f82261095faca3822c6f6.tar.gz
spark-5952bdb7df20d007d59f82261095faca3822c6f6.tar.bz2
spark-5952bdb7df20d007d59f82261095faca3822c6f6.zip
[SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression
Implement Python API for AFTSurvivalRegression.
Author: vectorijk <jiangkai@gmail.com>
Closes #8926 from vectorijk/spark-10688.
-rw-r--r-- python/pyspark/ml/regression.py 171
1 file changed, 169 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 21d454f900..a0f7f54e65 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -22,8 +22,10 @@ from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
-__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
- 'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
+__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
+ 'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
+ 'GBTRegressor', 'GBTRegressionModel',
+ 'LinearRegression', 'LinearRegressionModel',
'RandomForestRegressor', 'RandomForestRegressionModel']
@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels):
"""
+@inherit_doc
+class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+ HasFitIntercept, HasMaxIter, HasTol):
+ """
+ Accelerated Failure Time (AFT) Model Survival Regression
+
+ Fit a parametric AFT survival regression model based on the Weibull distribution
+ of the survival time.
+
+ .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(1.0), 1.0),
+ ... (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
+ >>> aftsr = AFTSurvivalRegression()
+ >>> model = aftsr.fit(df)
+ >>> model.predict(Vectors.dense(6.3))
+ 1.0
+ >>> model.predictQuantiles(Vectors.dense(6.3))
+ DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
+ >>> model.transform(df).show()
+ +-----+---------+------+----------+
+ |label| features|censor|prediction|
+ +-----+---------+------+----------+
+ | 1.0| [1.0]| 1.0| 1.0|
+ | 0.0|(1,[],[])| 0.0| 1.0|
+ +-----+---------+------+----------+
+ ...
+
+ .. versionadded:: 1.6.0
+ """
+
+ # a placeholder to make it appear in the generated doc
+ censorCol = Param(Params._dummy(), "censorCol",
+ "censor column name. The value of this column could be 0 or 1. " +
+ "If the value is 1, it means the event has occurred i.e. " +
+ "uncensored; otherwise censored.")
+ quantileProbabilities = \
+ Param(Params._dummy(), "quantileProbabilities",
+ "quantile probabilities array. Values of the quantile probabilities array " +
+ "should be in the range (0, 1) and the array should be non-empty.")
+ quantilesCol = Param(Params._dummy(), "quantilesCol",
+ "quantiles column name. This column will output quantiles of " +
+ "corresponding quantileProbabilities if it is set.")
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+ quantileProbabilities=None, quantilesCol=None):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+ quantilesCol=None):
+ """
+ super(AFTSurvivalRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
+ #: Param for censor column name
+ self.censorCol = Param(self, "censorCol",
+ "censor column name. The value of this column could be 0 or 1. " +
+ "If the value is 1, it means the event has occurred i.e. " +
+ "uncensored; otherwise censored.")
+ #: Param for quantile probabilities array
+ self.quantileProbabilities = \
+ Param(self, "quantileProbabilities",
+ "quantile probabilities array. Values of the quantile probabilities array " +
+ "should be in the range (0, 1) and the array should be non-empty.")
+ #: Param for quantiles column name
+ self.quantilesCol = Param(self, "quantilesCol",
+ "quantiles column name. This column will output quantiles of " +
+ "corresponding quantileProbabilities if it is set.")
+ self._setDefault(censorCol="censor",
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("1.6.0")
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+ quantileProbabilities=None, quantilesCol=None):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+ quantilesCol=None):
+ """
+ kwargs = self.setParams._input_kwargs
+ if quantileProbabilities is None:
+ return self._set(**kwargs).setQuantileProbabilities([0.01, 0.05, 0.1, 0.25, 0.5,
+ 0.75, 0.9, 0.95, 0.99])
+ else:
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return AFTSurvivalRegressionModel(java_model)
+
+ @since("1.6.0")
+ def setCensorCol(self, value):
+ """
+ Sets the value of :py:attr:`censorCol`.
+ """
+ self._paramMap[self.censorCol] = value
+ return self
+
+ @since("1.6.0")
+ def getCensorCol(self):
+ """
+ Gets the value of censorCol or its default value.
+ """
+ return self.getOrDefault(self.censorCol)
+
+ @since("1.6.0")
+ def setQuantileProbabilities(self, value):
+ """
+ Sets the value of :py:attr:`quantileProbabilities`.
+ """
+ self._paramMap[self.quantileProbabilities] = value
+ return self
+
+ @since("1.6.0")
+ def getQuantileProbabilities(self):
+ """
+ Gets the value of quantileProbabilities or its default value.
+ """
+ return self.getOrDefault(self.quantileProbabilities)
+
+ @since("1.6.0")
+ def setQuantilesCol(self, value):
+ """
+ Sets the value of :py:attr:`quantilesCol`.
+ """
+ self._paramMap[self.quantilesCol] = value
+ return self
+
+ @since("1.6.0")
+ def getQuantilesCol(self):
+ """
+ Gets the value of quantilesCol or its default value.
+ """
+ return self.getOrDefault(self.quantilesCol)
+
+
+class AFTSurvivalRegressionModel(JavaModel):
+ """
+ Model fitted by AFTSurvivalRegression.
+
+ .. versionadded:: 1.6.0
+ """
+
+ def predictQuantiles(self, features):
+ """
+ Predicted Quantiles
+ """
+ return self._call_java("predictQuantiles", features)
+
+ def predict(self, features):
+ """
+ Predicted value
+ """
+ return self._call_java("predict", features)
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext