aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
blob: c368aadd236696628a49b01b96892d2782ab8518 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.param

import java.lang.reflect.Modifier
import java.util.{List => JList}
import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.JavaConverters._

import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
 * :: DeveloperApi ::
 * A param with self-contained documentation and optionally default value. Primitive-typed param
 * should use the specialized versions, which are more friendly to Java users.
 *
 * @param parent parent object
 * @param name param name
 * @param doc documentation
 * @param isValid optional validation method which indicates if a value is valid.
 *                See [[ParamValidators]] for factory methods for common validation functions.
 * @tparam T param value type
 */
@DeveloperApi
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
  extends Serializable {

  /** Creates a param whose parent is given by `parent.uid`, with a custom validator. */
  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
    this(parent.uid, name, doc, isValid)

  /** Creates a param with the given parent uid that accepts every value. */
  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[T])

  /** Creates a param whose parent is given by `parent.uid` and that accepts every value. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /**
   * Assert that the given value is valid for this parameter.
   *
   * Note: Parameter checks involving interactions between multiple parameters and input/output
   * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
   *
   * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
   *             should be specified via [[ParamPair]].
   *
   * @param value candidate value; arrays are rendered element-wise in the error message
   * @throws IllegalArgumentException if the value is invalid
   */
  private[param] def validate(value: T): Unit = {
    if (!isValid(value)) {
      val valueToString = value match {
        case v: Array[_] => v.mkString("[", ",", "]")
        // String.valueOf renders a null value as "null" instead of throwing a
        // NullPointerException, so a rejected null still surfaces as IllegalArgumentException.
        case _ => String.valueOf(value)
      }
      throw new IllegalArgumentException(
        s"$parent parameter $name given invalid value $valueToString.")
    }
  }

  /** Creates a param pair with the given value (for Java). */
  def w(value: T): ParamPair[T] = this -> value

  /** Creates a param pair with the given value (for Scala). */
  // scalastyle:off
  def ->(value: T): ParamPair[T] = ParamPair(this, value)
  // scalastyle:on

  /**
   * Encodes a param value into JSON, which can be decoded by [[jsonDecode()]].
   *
   * The default implementation only handles `String` and `Vector` values.
   *
   * @throws NotImplementedError if the value is of any other type (including null)
   */
  def jsonEncode(value: T): String = {
    value match {
      case x: String =>
        compact(render(JString(x)))
      case v: Vector =>
        v.toJson
      case _ =>
        // A null value has no runtime class; guard so that the intended NotImplementedError is
        // raised rather than a NullPointerException from value.getClass.
        val valueClassName = if (value == null) "null" else value.getClass.getName
        throw new NotImplementedError(
          "The default jsonEncode only supports string and vector. " +
            s"${this.getClass.getName} must override jsonEncode for $valueClassName.")
    }
  }

  /** Decodes a param value from JSON. */
  def jsonDecode(json: String): T = Param.jsonDecode[T](json)

  // Cached because toString also backs hashCode.
  private[this] val stringRepresentation = s"${parent}__$name"

  /** Returns the parent uid and the param name joined by a double underscore. */
  override final def toString: String = stringRepresentation

  override final def hashCode: Int = toString.##

  /** Two params are equal when both their parent uid and their name match. */
  override final def equals(obj: Any): Boolean = {
    obj match {
      case p: Param[_] => (p.parent == parent) && (p.name == name)
      case _ => false
    }
  }
}

private[ml] object Param {

  /**
   * Decodes a param value from JSON.
   *
   * A JSON string is returned as is; a JSON object is handed to `Vectors.fromJson` after checking
   * that it carries the fields 'type' and 'values'. The result is cast to `T` without a runtime
   * check, so the caller must request the matching type.
   *
   * @throws IllegalArgumentException if a JSON object lacks the fields 'type' or 'values'
   * @throws NotImplementedError if the JSON is neither a string nor an object
   */
  def jsonDecode[T](json: String): T = {
    parse(json) match {
      case JString(x) =>
        x.asInstanceOf[T]
      case JObject(v) =>
        val keys = v.map(_._1)
        // Input validation uses require rather than assert: assertions can be elided at compile
        // time and raise AssertionError, which is not meant for reporting bad input.
        require(keys.contains("type") && keys.contains("values"),
          s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.")
        Vectors.fromJson(json).asInstanceOf[T]
      case _ =>
        throw new NotImplementedError(
          "The default jsonDecode only supports string and vector. " +
            s"${this.getClass.getName} must override jsonDecode to support its value type.")
    }
  }
}

/**
 * :: DeveloperApi ::
 * Factory methods for common validation functions for [[Param.isValid]].
 * The numerical methods only support Int, Long, Float, and Double.
 */
@DeveloperApi
object ParamValidators {

  /** (private[param]) Validator that accepts every value; used when no validator is supplied. */
  private[param] def alwaysTrue[T]: T => Boolean = _ => true

  /**
   * Private helper that widens a supported numerical value (Int, Long, Float, Double) to Double.
   * It exists mainly so the generic validators compile; type checks are really handled
   * by [[Params]] setters and the [[ParamPair]] constructor.
   *
   * @throws IllegalArgumentException if the value is of any other type
   */
  private def getDouble[T](value: T): Double = {
    value match {
      case i: Int => i.toDouble
      case l: Long => l.toDouble
      case f: Float => f.toDouble
      case d: Double => d
      case _ =>
        // The type should be checked before this is ever called.
        throw new IllegalArgumentException("Numerical Param validation failed because" +
          s" of unexpected input type: ${value.getClass}")
    }
  }

  /** Builds a validator that accepts values strictly greater than lowerBound. */
  def gt[T](lowerBound: Double): T => Boolean = value => getDouble(value) > lowerBound

  /** Builds a validator that accepts values greater than or equal to lowerBound. */
  def gtEq[T](lowerBound: Double): T => Boolean = value => getDouble(value) >= lowerBound

  /** Builds a validator that accepts values strictly less than upperBound. */
  def lt[T](upperBound: Double): T => Boolean = value => getDouble(value) < upperBound

  /** Builds a validator that accepts values less than or equal to upperBound. */
  def ltEq[T](upperBound: Double): T => Boolean = value => getDouble(value) <= upperBound

  /**
   * Builds a validator that accepts values lying between lowerBound and upperBound.
   * @param lowerInclusive  If true, the value may equal lowerBound (value >= lowerBound).
   *                        If false, it must exceed it (value > lowerBound).
   * @param upperInclusive  If true, the value may equal upperBound (value <= upperBound).
   *                        If false, it must stay below it (value < upperBound).
   */
  def inRange[T](
      lowerBound: Double,
      upperBound: Double,
      lowerInclusive: Boolean,
      upperInclusive: Boolean): T => Boolean = { (value: T) =>
    val number = getDouble(value)
    val lowerOk = if (lowerInclusive) lowerBound <= number else lowerBound < number
    val upperOk = if (upperInclusive) upperBound >= number else upperBound > number
    lowerOk && upperOk
  }

  /** Version of [[inRange()]] which is inclusive by default: [lowerBound, upperBound] */
  def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
    inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
  }

  /** Builds a validator that accepts only values contained in the given array. */
  def inArray[T](allowed: Array[T]): T => Boolean = candidate => allowed.contains(candidate)

  /** Builds a validator that accepts only values contained in the given Java list. */
  def inArray[T](allowed: java.util.List[T]): T => Boolean = {
    candidate => allowed.contains(candidate)
  }

  /** Builds a validator that accepts arrays whose length is greater than lowerBound. */
  def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = {
    array => array.length > lowerBound
  }
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Double]]] for Java.
 */
@DeveloperApi
class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)
  extends Param[Double](parent, name, doc, isValid) {

  /** Variant taking the owner itself; its uid becomes the parent. */
  def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) =
    this(parent.uid, name, doc, isValid)

  /** Variant without a validator: every value is accepted. */
  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Double])

  /** Variant taking the owner itself and no validator. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /** Pairs this param with the given value (Java-friendly entry point). */
  override def w(value: Double): ParamPair[Double] = super.w(value)

  /** Renders the value as compact JSON using the encoding of the companion object. */
  override def jsonEncode(value: Double): String = {
    val encoded = DoubleParam.jValueEncode(value)
    compact(render(encoded))
  }

  /** Parses the JSON text and decodes it with the companion object. */
  override def jsonDecode(json: String): Double = {
    val parsed = parse(json)
    DoubleParam.jValueDecode(parsed)
  }
}

private[param] object DoubleParam {
  /**
   * Encodes a param value into JValue.
   *
   * NaN and the two infinities are written as the strings "NaN", "-Inf" and "Inf"; every other
   * value is written as a `JDouble`.
   */
  def jValueEncode(value: Double): JValue = {
    value match {
      // NaN never equals itself, so it has to be detected with a guard, not a literal pattern.
      case _ if value.isNaN =>
        JString("NaN")
      case Double.NegativeInfinity =>
        JString("-Inf")
      case Double.PositiveInfinity =>
        JString("Inf")
      case _ =>
        JDouble(value)
    }
  }

  /**
   * Decodes a param value from JValue.
   *
   * Accepts everything produced by [[jValueEncode()]] and, in addition, the other numeric nodes
   * (`JInt`, `JDecimal`), so that a number written without a fractional part still decodes.
   *
   * @throws IllegalArgumentException if the JValue is not one of the supported forms
   */
  def jValueDecode(jValue: JValue): Double = {
    jValue match {
      case JString("NaN") =>
        Double.NaN
      case JString("-Inf") =>
        Double.NegativeInfinity
      case JString("Inf") =>
        Double.PositiveInfinity
      case JDouble(x) =>
        x
      case JInt(x) =>
        x.toDouble
      case JDecimal(x) =>
        x.toDouble
      case _ =>
        throw new IllegalArgumentException(s"Cannot decode $jValue to Double.")
    }
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Int]]] for Java.
 */
@DeveloperApi
class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
  extends Param[Int](parent, name, doc, isValid) {

  /** Variant taking the owner itself; its uid becomes the parent. */
  def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) =
    this(parent.uid, name, doc, isValid)

  /** Variant without a validator: every value is accepted. */
  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Int])

  /** Variant taking the owner itself and no validator. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /** Pairs this param with the given value (Java-friendly entry point). */
  override def w(value: Int): ParamPair[Int] = super.w(value)

  /** Renders the value as a compact JSON integer. */
  override def jsonEncode(value: Int): String = {
    val encoded: JValue = JInt(value)
    compact(render(encoded))
  }

  /** Parses the JSON text and extracts an Int using the default json4s formats. */
  override def jsonDecode(json: String): Int = {
    implicit val jsonFormats: Formats = DefaultFormats
    val parsed = parse(json)
    parsed.extract[Int]
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Float]]] for Java.
 */
@DeveloperApi
class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)
  extends Param[Float](parent, name, doc, isValid) {

  /** Variant taking the owner itself; its uid becomes the parent. */
  def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) =
    this(parent.uid, name, doc, isValid)

  /** Variant without a validator: every value is accepted. */
  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Float])

  /** Variant taking the owner itself and no validator. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /** Pairs this param with the given value (Java-friendly entry point). */
  override def w(value: Float): ParamPair[Float] = super.w(value)

  /** Renders the value as compact JSON using the encoding of the companion object. */
  override def jsonEncode(value: Float): String = {
    val encoded = FloatParam.jValueEncode(value)
    compact(render(encoded))
  }

  /** Parses the JSON text and decodes it with the companion object. */
  override def jsonDecode(json: String): Float = {
    val parsed = parse(json)
    FloatParam.jValueDecode(parsed)
  }
}

private object FloatParam {

  /**
   * Encodes a param value into JValue.
   *
   * NaN and the two infinities are written as the strings "NaN", "-Inf" and "Inf"; every other
   * value is widened to Double and written as a `JDouble`.
   */
  def jValueEncode(value: Float): JValue = {
    value match {
      // NaN never equals itself, so it has to be detected with a guard, not a literal pattern.
      case _ if value.isNaN =>
        JString("NaN")
      case Float.NegativeInfinity =>
        JString("-Inf")
      case Float.PositiveInfinity =>
        JString("Inf")
      case _ =>
        JDouble(value)
    }
  }

  /**
   * Decodes a param value from JValue.
   *
   * Accepts everything produced by [[jValueEncode()]] and, in addition, the other numeric nodes
   * (`JInt`, `JDecimal`), so that a number written without a fractional part still decodes.
   *
   * @throws IllegalArgumentException if the JValue is not one of the supported forms
   */
  def jValueDecode(jValue: JValue): Float = {
    jValue match {
      case JString("NaN") =>
        Float.NaN
      case JString("-Inf") =>
        Float.NegativeInfinity
      case JString("Inf") =>
        Float.PositiveInfinity
      case JDouble(x) =>
        x.toFloat
      case JInt(x) =>
        x.toFloat
      case JDecimal(x) =>
        x.toFloat
      case _ =>
        throw new IllegalArgumentException(s"Cannot decode $jValue to Float.")
    }
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Long]]] for Java.
 */
@DeveloperApi
class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)
  extends Param[Long](parent, name, doc, isValid) {

  /** Variant taking the owner itself; its uid becomes the parent. */
  def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) =
    this(parent.uid, name, doc, isValid)

  /** Variant without a validator: every value is accepted. */
  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Long])

  /** Variant taking the owner itself and no validator. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /** Pairs this param with the given value (Java-friendly entry point). */
  override def w(value: Long): ParamPair[Long] = super.w(value)

  /** Renders the value as a compact JSON integer. */
  override def jsonEncode(value: Long): String = {
    val encoded: JValue = JInt(value)
    compact(render(encoded))
  }

  /** Parses the JSON text and extracts a Long using the default json4s formats. */
  override def jsonDecode(json: String): Long = {
    implicit val jsonFormats: Formats = DefaultFormats
    val parsed = parse(json)
    parsed.extract[Long]
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Boolean]]] for Java.
 */
@DeveloperApi
class BooleanParam(parent: String, name: String, doc: String) // No need for isValid
  extends Param[Boolean](parent, name, doc) {

  /** Variant taking the owner itself; its uid becomes the parent. */
  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /** Pairs this param with the given value (Java-friendly entry point). */
  override def w(value: Boolean): ParamPair[Boolean] = super.w(value)

  /** Renders the value as a compact JSON boolean. */
  override def jsonEncode(value: Boolean): String = {
    val encoded: JValue = JBool(value)
    compact(render(encoded))
  }

  /** Parses the JSON text and extracts a Boolean using the default json4s formats. */
  override def jsonDecode(json: String): Boolean = {
    implicit val jsonFormats: Formats = DefaultFormats
    val parsed = parse(json)
    parsed.extract[Boolean]
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Array[String]]]] for Java.
 */
@DeveloperApi
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
  extends Param[Array[String]](parent, name, doc, isValid) {

  /** Variant without a validator: every value is accepted. */
  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Array[String]])

  /** Pairs this param with the elements of a [[java.util.List]] (for Java and Python). */
  def w(value: java.util.List[String]): ParamPair[Array[String]] = {
    val asArray: Array[String] = value.asScala.toArray
    w(asArray)
  }

  /** Renders the array as a compact JSON array of strings. */
  override def jsonEncode(value: Array[String]): String = {
    import org.json4s.JsonDSL._
    val elements: Seq[String] = value.toSeq
    compact(render(elements))
  }

  /** Parses the JSON text and extracts its elements as an array of strings. */
  override def jsonDecode(json: String): Array[String] = {
    implicit val jsonFormats: Formats = DefaultFormats
    val elements = parse(json).extract[Seq[String]]
    elements.toArray
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Array[Double]]]] for Java.
 */
@DeveloperApi
class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
  extends Param[Array[Double]](parent, name, doc, isValid) {

  /** Variant without a validator: every value is accepted. */
  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Array[Double]])

  /** Pairs this param with the elements of a [[java.util.List]] (for Java and Python). */
  def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = {
    val unboxed: Array[Double] = value.asScala.map(_.asInstanceOf[Double]).toArray
    w(unboxed)
  }

  /** Renders the array as compact JSON, encoding each element like [[DoubleParam]] does. */
  override def jsonEncode(value: Array[Double]): String = {
    import org.json4s.JsonDSL._
    val encodedElements = value.toSeq.map(DoubleParam.jValueEncode)
    compact(render(encodedElements))
  }

  /**
   * Parses the JSON text, which must be an array, and decodes each element like
   * [[DoubleParam]] does.
   */
  override def jsonDecode(json: String): Array[Double] = {
    val decoded = parse(json) match {
      case JArray(elements) =>
        elements.map(DoubleParam.jValueDecode)
      case _ =>
        throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].")
    }
    decoded.toArray
  }
}

/**
 * :: DeveloperApi ::
 * Specialized version of [[Param[Array[Int]]]] for Java.
 */
@DeveloperApi
class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
  extends Param[Array[Int]](parent, name, doc, isValid) {

  /** Variant without a validator: every value is accepted. */
  def this(parent: Params, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[Array[Int]])

  /** Pairs this param with the elements of a [[java.util.List]] (for Java and Python). */
  def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = {
    val unboxed: Array[Int] = value.asScala.map(_.asInstanceOf[Int]).toArray
    w(unboxed)
  }

  /** Renders the array as a compact JSON array of integers. */
  override def jsonEncode(value: Array[Int]): String = {
    import org.json4s.JsonDSL._
    val elements: Seq[Int] = value.toSeq
    compact(render(elements))
  }

  /** Parses the JSON text and extracts its elements as an array of integers. */
  override def jsonDecode(json: String): Array[Int] = {
    implicit val jsonFormats: Formats = DefaultFormats
    val elements = parse(json).extract[Seq[Int]]
    elements.toArray
  }
}

/**
 * :: Experimental ::
 * A param and its value.
 */
@Since("1.2.0")
@Experimental
case class ParamPair[T] @Since("1.2.0") (
    @Since("1.2.0") param: Param[T],
    @Since("1.2.0") value: T) {
  // This is *the* place Param.validate is called.  Whenever a parameter is specified, we should
  // always construct a ParamPair so that validate is called.
  // The check runs as part of construction, so pairing a param with a value its validator
  // rejects fails here with the IllegalArgumentException thrown by Param.validate.
  param.validate(value)
}

/**
 * :: DeveloperApi ::
 * Trait for components that take parameters. This also provides an internal param map to store
 * parameter values attached to the instance.
 */
@DeveloperApi
trait Params extends Identifiable with Serializable {

  /**
   * Returns all params sorted by their names. The default implementation uses Java reflection to
   * list all public methods that have no arguments and return [[Param]].
   *
   * Note: Developer should not use this method in constructor because we cannot guarantee that
   * this variable gets initialized before other params.
   */
  lazy val params: Array[Param[_]] = {
    val methods = this.getClass.getMethods
    methods.filter { m =>
        Modifier.isPublic(m.getModifiers) &&
          classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
          m.getParameterTypes.isEmpty
      }.sortBy(_.getName)
      .map(m => m.invoke(this).asInstanceOf[Param[_]])
  }

  /**
   * Validates parameter values stored internally.
   * Raise an exception if any parameter value is invalid.
   *
   * This only needs to check for interactions between parameters.
   * Parameter value checks which do not depend on other parameters are handled by
   * [[Param.validate()]].  This method does not handle input/output column parameters;
   * those are checked during schema validation.
   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
   */
  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
  def validateParams(): Unit = {
    // Do nothing by default.  Override to handle Param interactions.
  }

  /**
   * Explains a param.
   * @param param input param, must belong to this instance.
   * @return a string that contains the input param name, doc, and optionally its default value and
   *         the user-supplied value
   */
  def explainParam(param: Param[_]): String = {
    // Ownership is verified once; the maps are then read directly instead of going through
    // isDefined/getDefault/get, each of which would repeat the same scan over all params.
    shouldOwn(param)
    val defaultValueStr = defaultParamMap.get(param).map("default: " + _)
    val currentValueStr = paramMap.get(param).map("current: " + _)
    val valueStr = if (defaultValueStr.isDefined || currentValueStr.isDefined) {
      (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
    } else {
      "(undefined)"
    }
    s"${param.name}: ${param.doc} $valueStr"
  }

  /**
   * Explains all params of this instance, one per line.
   * @see [[explainParam()]]
   */
  def explainParams(): String = {
    params.map(explainParam).mkString("\n")
  }

  /** Checks whether a param is explicitly set. */
  final def isSet(param: Param[_]): Boolean = {
    shouldOwn(param)
    paramMap.contains(param)
  }

  /** Checks whether a param is explicitly set or has a default value. */
  final def isDefined(param: Param[_]): Boolean = {
    shouldOwn(param)
    defaultParamMap.contains(param) || paramMap.contains(param)
  }

  /** Tests whether this instance contains a param with a given name. */
  def hasParam(paramName: String): Boolean = {
    params.exists(_.name == paramName)
  }

  /**
   * Gets a param by its name.
   * @throws NoSuchElementException if this instance has no param with that name
   */
  def getParam(paramName: String): Param[Any] = {
    params.find(_.name == paramName).getOrElse {
      throw new NoSuchElementException(s"Param $paramName does not exist.")
    }.asInstanceOf[Param[Any]]
  }

  /**
   * Sets a parameter in the embedded param map.
   */
  final def set[T](param: Param[T], value: T): this.type = {
    set(param -> value)
  }

  /**
   * Sets a parameter (by name) in the embedded param map.
   */
  protected final def set(param: String, value: Any): this.type = {
    set(getParam(param), value)
  }

  /**
   * Sets a parameter in the embedded param map.
   */
  protected final def set(paramPair: ParamPair[_]): this.type = {
    shouldOwn(paramPair.param)
    paramMap.put(paramPair)
    this
  }

  /**
   * Optionally returns the user-supplied value of a param.
   */
  final def get[T](param: Param[T]): Option[T] = {
    shouldOwn(param)
    paramMap.get(param)
  }

  /**
   * Clears the user-supplied value for the input param.
   */
  final def clear(param: Param[_]): this.type = {
    shouldOwn(param)
    paramMap.remove(param)
    this
  }

  /**
   * Gets the value of a param in the embedded param map or its default value. Throws an exception
   * if neither is set.
   *
   * @throws NoSuchElementException if the param has neither a user-supplied nor a default value
   */
  final def getOrDefault[T](param: Param[T]): T = {
    // Ownership is verified once; reading the maps directly avoids the repeated checks that
    // get and getDefault would each perform.
    shouldOwn(param)
    paramMap.get(param).orElse(defaultParamMap.get(param)).getOrElse(
      throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
  }

  /** An alias for [[getOrDefault()]]. */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

  /**
   * Sets a default value for a param.
   * @param param  param to set the default value. Make sure that this param is initialized before
   *               this method gets called.
   * @param value  the default value
   */
  protected final def setDefault[T](param: Param[T], value: T): this.type = {
    defaultParamMap.put(param -> value)
    this
  }

  /**
   * Sets default values for a list of params.
   *
   * Note: Java developers should use the single-parameter [[setDefault()]].
   *       Annotating this with varargs can cause compilation failures due to a Scala compiler bug.
   *       See SPARK-9268.
   *
   * @param paramPairs  a list of param pairs that specify params and their default values to set
   *                    respectively. Make sure that the params are initialized before this method
   *                    gets called.
   */
  protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
    paramPairs.foreach { p =>
      setDefault(p.param.asInstanceOf[Param[Any]], p.value)
    }
    this
  }

  /**
   * Gets the default value of a parameter.
   */
  final def getDefault[T](param: Param[T]): Option[T] = {
    shouldOwn(param)
    defaultParamMap.get(param)
  }

  /**
   * Tests whether the input param has a default value set.
   */
  final def hasDefault[T](param: Param[T]): Boolean = {
    shouldOwn(param)
    defaultParamMap.contains(param)
  }

  /**
   * Creates a copy of this instance with the same UID and some extra params.
   * Subclasses should implement this method and set the return type properly.
   *
   * @see [[defaultCopy()]]
   */
  def copy(extra: ParamMap): Params

  /**
   * Default implementation of copy with extra params.
   * It tries to create a new instance with the same UID, using the constructor of the runtime
   * class that takes a single String.
   * Then it copies the embedded and extra parameters over and returns the new instance.
   */
  protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
    val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
    copyValues(that, extra).asInstanceOf[T]
  }

  /**
   * Extracts the embedded default param values and user-supplied values, and then merges them with
   * extra values from input into a flat param map, where the latter value is used if there exist
   * conflicts, i.e., with ordering: default param values < user-supplied values < extra.
   */
  final def extractParamMap(extra: ParamMap): ParamMap = {
    defaultParamMap ++ paramMap ++ extra
  }

  /**
   * [[extractParamMap]] with no extra values.
   */
  final def extractParamMap(): ParamMap = {
    extractParamMap(ParamMap.empty)
  }

  /** Internal param map for user-supplied values. */
  private val paramMap: ParamMap = ParamMap.empty

  /** Internal param map for default values. */
  private val defaultParamMap: ParamMap = ParamMap.empty

  /**
   * Validates that the input param belongs to this instance: its parent must equal this
   * instance's uid and a param of that name must exist here.
   */
  private def shouldOwn(param: Param[_]): Unit = {
    require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.")
  }

  /**
   * Copies param values from this instance to another instance for params shared by them.
   *
   * This handles default Params and explicitly set Params separately.
   * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are
   * copied from and to [[paramMap]].
   * Warning: This implicitly assumes that this [[Params]] instance and the target instance
   *          share the same set of default Params.
   *
   * @param to the target instance, which should work with the same set of default Params as this
   *           source instance
   * @param extra extra params to be copied to the target's [[paramMap]]
   * @return the target instance with param values copied
   */
  protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
    val map = paramMap ++ extra
    params.foreach { param =>
      // copy default Params
      if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
      }
      // copy explicitly set Params
      if (map.contains(param) && to.hasParam(param.name)) {
        to.set(param.name, map(param))
      }
    }
    to
  }
}

/**
 * :: DeveloperApi ::
 * Java-friendly wrapper for [[Params]].
 * Java developers who need to extend [[Params]] should use this class instead.
 * If you need to extend an abstract class which already extends [[Params]], then that abstract
 * class should be Java-friendly as well.
 */
@DeveloperApi
abstract class JavaParams extends Params

/**
 * :: Experimental ::
 * A param to value map.
 */
@Since("1.2.0")
@Experimental
final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
  extends Serializable {

  /* DEVELOPERS: About validating parameter values
   *   This and ParamPair are the only two collections of parameters.
   *   This class should always create ParamPairs when
   *   specifying new parameter values.  ParamPair will then call Param.validate().
   */

  /**
   * Creates an empty param map.
   */
  @Since("1.2.0")
  def this() = this(mutable.Map.empty)

  /**
   * Stores a value for a param, replacing any value already stored for it.
   */
  @Since("1.2.0")
  def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value))

  /**
   * Stores every given param pair, replacing values already stored for the same params.
   */
  @varargs
  @Since("1.2.0")
  def put(paramPairs: ParamPair[_]*): this.type = {
    for (pair <- paramPairs) {
      map.update(pair.param.asInstanceOf[Param[Any]], pair.value)
    }
    this
  }

  /** Put param pairs with a [[java.util.List]] of values for Python. */
  private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
    val pairs = paramPairs.asScala
    put(pairs: _*)
  }

  /**
   * Looks up the value stored for a param, if there is one.
   */
  @Since("1.2.0")
  def get[T](param: Param[T]): Option[T] = {
    val key = param.asInstanceOf[Param[Any]]
    map.get(key).asInstanceOf[Option[T]]
  }

  /**
   * Returns the value stored for a param, or the supplied default when nothing is stored.
   */
  @Since("1.4.0")
  def getOrElse[T](param: Param[T], default: T): T = {
    get(param) match {
      case Some(stored) => stored
      case None => default
    }
  }

  /**
   * Returns the value stored for the input param.
   * Raises a NoSuchElementException if there is no value associated with the input param.
   */
  @Since("1.2.0")
  def apply[T](param: Param[T]): T = {
    get(param) match {
      case Some(stored) => stored
      case None => throw new NoSuchElementException(s"Cannot find param ${param.name}.")
    }
  }

  /**
   * Tells whether a value is stored for the given param.
   */
  @Since("1.2.0")
  def contains(param: Param[_]): Boolean = {
    val key = param.asInstanceOf[Param[Any]]
    map.contains(key)
  }

  /**
   * Drops the entry of a param and returns the value it held, if any.
   */
  @Since("1.4.0")
  def remove[T](param: Param[T]): Option[T] = {
    val key = param.asInstanceOf[Param[Any]]
    map.remove(key).asInstanceOf[Option[T]]
  }

  /**
   * Returns a new param map holding only the entries whose param parent is the uid of the
   * given instance.
   */
  @Since("1.2.0")
  def filter(parent: Params): ParamMap = {
    // Don't use filterKeys because mutable.Map#filterKeys
    // returns the instance of collections.Map, not mutable.Map.
    // Otherwise, we get ClassCastException.
    // Not using filterKeys also avoid SI-6654
    val owned = map.filter(entry => entry._1.parent == parent.uid)
    new ParamMap(owned)
  }

  /**
   * Creates a copy of this param map, backed by a clone of the underlying map.
   */
  @Since("1.2.0")
  def copy: ParamMap = {
    val cloned = map.clone()
    new ParamMap(cloned)
  }

  /** Lists the entries ordered by param name, one `parent-name: value` line per entry. */
  @Since("1.2.0")
  override def toString: String = {
    val ordered = map.toSeq.sortBy(entry => entry._1.name)
    val lines = ordered.map { case (param, value) =>
      s"\t${param.parent}-${param.name}: $value"
    }
    lines.mkString("{\n", ",\n", "\n}")
  }

  /**
   * Returns a new param map that contains parameters in this map and the given map,
   * where the latter overwrites this if there exist conflicts.
   */
  @Since("1.2.0")
  def ++(other: ParamMap): ParamMap = {
    // TODO: Provide a better method name for Java users.
    val merged = this.map ++ other.map
    new ParamMap(merged)
  }

  /**
   * Adds all parameters from the input param map into this param map.
   */
  @Since("1.2.0")
  def ++=(other: ParamMap): this.type = {
    // TODO: Provide a better method name for Java users.
    map ++= other.map
    this
  }

  /**
   * Converts this param map to a sequence of param pairs.
   */
  @Since("1.2.0")
  def toSeq: Seq[ParamPair[_]] = {
    map.toSeq.map(entry => ParamPair(entry._1, entry._2))
  }

  /** Java-friendly method for Python API */
  private[ml] def toList: java.util.List[ParamPair[_]] = {
    val pairs = this.toSeq
    pairs.asJava
  }

  /**
   * Number of param pairs in this map.
   */
  @Since("1.3.0")
  def size: Int = map.size
}

@Since("1.2.0")
@Experimental
object ParamMap {

  /**
   * Creates a new param map with no entries.
   */
  @Since("1.2.0")
  def empty: ParamMap = new ParamMap()

  /**
   * Builds a param map that holds the given param pairs.
   */
  @varargs
  @Since("1.2.0")
  def apply(paramPairs: ParamPair[_]*): ParamMap = {
    val result = new ParamMap()
    result.put(paramPairs: _*)
  }
}