aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala14
-rwxr-xr-xpython/pyspark/ml/feature.py12
3 files changed, 26 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 2916b6d9df..a7ca0fe252 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -182,7 +182,7 @@ class RFormula(override val uid: String)
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
- override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
+ override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 19aecff038..2dd565a782 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -126,7 +126,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
* @param hasIntercept whether the formula specifies fitting with an intercept.
*/
private[ml] case class ResolvedRFormula(
- label: String, terms: Seq[Seq[String]], hasIntercept: Boolean)
+ label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) {
+
+ override def toString: String = {
+ val ts = terms.map {
+ case t if t.length > 1 =>
+ s"${t.mkString("{", ",", "}")}"
+ case t =>
+ t.mkString
+ }
+ val termStr = ts.mkString("[", ",", "]")
+ s"ResolvedRFormula(label=$label, terms=$termStr, hasIntercept=$hasIntercept)"
+ }
+}
/**
* R formula terms. See the R formula docs here for more information:
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index bfb2fb7071..ca77ac395d 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2528,6 +2528,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
True
>>> loadedRF.getLabelCol() == rf.getLabelCol()
True
+ >>> str(loadedRF)
+ 'RFormula(y ~ x + s) (uid=...)'
>>> modelPath = temp_path + "/rFormulaModel"
>>> model.save(modelPath)
>>> loadedModel = RFormulaModel.load(modelPath)
@@ -2542,6 +2544,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|0.0|0.0| a|[0.0,1.0]| 0.0|
+---+---+---+---------+-----+
...
+ >>> str(loadedModel)
+ 'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
.. versionadded:: 1.5.0
"""
@@ -2586,6 +2590,10 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
def _create_model(self, java_model):
return RFormulaModel(java_model)
+ def __str__(self):
+ formulaStr = self.getFormula() if self.isDefined(self.formula) else ""
+ return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid)
+
class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
@@ -2597,6 +2605,10 @@ class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
.. versionadded:: 1.5.0
"""
+ def __str__(self):
+ resolvedFormula = self._call_java("resolvedFormula")
+ return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid)
+
@inherit_doc
class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable,