Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--  python/pyspark/ml/classification.py | 38
1 file changed, 20 insertions(+), 18 deletions(-)
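The change is mechanical: every doctest that built a DataFrame through the `sqlContext` handle now goes through the unified `spark` (SparkSession) entry point that superseded SQLContext in Spark 2.0. A minimal sketch of the equivalence, assuming a local Spark 2.x install (the app name here is illustrative, not from the commit):

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors

    # SparkSession subsumes SQLContext; getOrCreate() reuses a running session.
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("sqlcontext-migration-sketch") \
        .getOrCreate()

    # createDataFrame takes the same arguments on SparkSession as it did on
    # SQLContext, so each doctest needs only its entry-point name changed.
    df = spark.createDataFrame([
        (1.0, Vectors.dense(1.0)),
        (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

    spark.stop()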
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index a1c3f72984..ea660d7808 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -498,7 +498,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
- >>> df = sqlContext.createDataFrame([
+ >>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -512,7 +512,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
- >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
@@ -520,7 +520,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
- >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
@@ -627,7 +627,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
- >>> df = sqlContext.createDataFrame([
+ >>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -639,7 +639,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
- >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
@@ -647,7 +647,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
0
>>> numpy.argmax(result.rawPrediction)
0
- >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> rfc_path = temp_path + "/rfc"
@@ -754,7 +754,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> from numpy import allclose
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
- >>> df = sqlContext.createDataFrame([
+ >>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
@@ -766,10 +766,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
- >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
- >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> gbtc_path = temp_path + "gbtc"
@@ -885,7 +885,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
- >>> df = sqlContext.createDataFrame([
+ >>> df = spark.createDataFrame([
... Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
... Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
@@ -1029,7 +1029,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
Number of outputs has to be equal to the total number of labels.
>>> from pyspark.ml.linalg import Vectors
- >>> df = sqlContext.createDataFrame([
+ >>> df = spark.createDataFrame([
... (0.0, Vectors.dense([0.0, 0.0])),
... (1.0, Vectors.dense([0.0, 1.0])),
... (1.0, Vectors.dense([1.0, 0.0])),
@@ -1040,7 +1040,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
[2, 5, 2]
>>> model.weights.size
27
- >>> testDF = sqlContext.createDataFrame([
+ >>> testDF = spark.createDataFrame([
... (Vectors.dense([1.0, 0.0]),),
... (Vectors.dense([0.0, 0.0]),)], ["features"])
>>> model.transform(testDF).show()
@@ -1467,21 +1467,23 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
if __name__ == "__main__":
import doctest
import pyspark.ml.classification
- from pyspark.context import SparkContext
- from pyspark.sql import SQLContext
+ from pyspark.sql import SparkSession
globs = pyspark.ml.classification.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- sc = SparkContext("local[2]", "ml.classification tests")
- sqlContext = SQLContext(sc)
+ spark = SparkSession.builder\
+ .master("local[2]")\
+ .appName("ml.classification tests")\
+ .getOrCreate()
+ sc = spark.sparkContext
globs['sc'] = sc
- globs['sqlContext'] = sqlContext
+ globs['spark'] = spark
import tempfile
temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- sc.stop()
+ spark.stop()
finally:
from shutil import rmtree
try:
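The last hunk rewires the doctest harness itself: the session is built once, `sc` is derived from it for examples that still touch the SparkContext, and teardown becomes spark.stop(), which also stops the underlying SparkContext. A standalone sketch of that pattern, assuming the temp-directory cleanup mirrors the truncated tail of the hunk:

    import doctest
    import tempfile
    from shutil import rmtree

    from pyspark.sql import SparkSession
    import pyspark.ml.classification

    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("ml.classification tests") \
        .getOrCreate()

    globs = pyspark.ml.classification.__dict__.copy()
    globs["spark"] = spark
    globs["sc"] = spark.sparkContext  # kept for doctests that use the RDD API
    globs["temp_path"] = tempfile.mkdtemp()
    try:
        failure_count, test_count = doctest.testmod(
            pyspark.ml.classification, globs=globs,
            optionflags=doctest.ELLIPSIS)
        spark.stop()  # also stops the underlying SparkContext
    finally:
        rmtree(globs["temp_path"], ignore_errors=True)  # assumed cleanup step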