Class GridSearchParameterLearner<S extends Observation,T extends Category>

  • Type Parameters:
    S - the type of observations.
    T - the type of categories.
    All Implemented Interfaces:
    Trainer<S,T>

    public class GridSearchParameterLearner<S extends Observation,T extends Category>
    extends ParameterTrainer<S,T>
    A grid-search approach for learning parameters. For each parameter of the given trainer, let I=[l,u] be the interval of admissible values for that parameter; I is divided into a number of sub-intervals given by partitions. For every combination of sub-intervals (one per parameter), the border points are chosen, a new classifier is trained with the resulting parameter values, and its performance is measured. The combination for which the classifier performs best is selected. If depth > 1, the process is iterated: the selected sub-intervals are divided again in the same way, and the search is repeated until depth levels of refinement have been carried out.
    Author:
    Matthias Thimm
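    The search described above can be sketched in plain Java as follows. This is a minimal, hypothetical illustration and not the Tweety implementation: the class GridSearchSketch, its methods, and the scoring function are assumptions. Each parameter is modeled as a plain double in [lower[i], upper[i]], sub-intervals are assumed to have equal width, and each sub-interval is represented by its center point (following the description of adjustParameterSet) rather than by its border points. A ToDoubleFunction stands in for training a classifier and measuring it with the tester; higher scores are assumed to be better.

```java
import java.util.function.ToDoubleFunction;

/**
 * Hypothetical grid-search sketch; NOT the Tweety implementation.
 * Parameter i ranges over [lower[i], upper[i]]. Each level splits every
 * interval into `partitions` equal-width sub-intervals (assumption), scores
 * the center point of every combination of sub-intervals, and narrows each
 * interval to the sub-interval of the best-scoring combination.
 */
final class GridSearchSketch {

    /** Returns the best point found; higher scores are assumed to be better. */
    static double[] search(ToDoubleFunction<double[]> score,
                           double[] lower, double[] upper,
                           int depth, int partitions) {
        if (depth < 1 || partitions < 1 || lower.length != upper.length) {
            throw new IllegalArgumentException("invalid arguments");
        }
        int k = lower.length;
        double[] lo = lower.clone();
        double[] hi = upper.clone();
        double[] best = null;                // set during the first level
        for (int level = 0; level < depth; level++) {
            int[] idx = new int[k];          // one sub-interval index per parameter
            int[] bestIdx = new int[k];
            double bestScore = Double.NEGATIVE_INFINITY;
            do {
                // Stand-in for "train a classifier with these values and test it".
                double[] point = centers(idx, lo, hi, partitions);
                double s = score.applyAsDouble(point);
                if (s > bestScore) {
                    bestScore = s;
                    bestIdx = idx.clone();
                    best = point;
                }
            } while (!increment(idx, partitions - 1));
            // Refine: keep only the winning sub-interval of every parameter.
            for (int i = 0; i < k; i++) {
                double step = (hi[i] - lo[i]) / partitions;
                lo[i] += bestIdx[i] * step;
                hi[i] = lo[i] + step;
            }
        }
        return best;
    }

    /** Center of sub-interval idx[i] of [lo[i], hi[i]] for every parameter. */
    private static double[] centers(int[] idx, double[] lo, double[] hi, int partitions) {
        double[] p = new double[idx.length];
        for (int i = 0; i < idx.length; i++) {
            double step = (hi[i] - lo[i]) / partitions;
            p[i] = lo[i] + (idx[i] + 0.5) * step;
        }
        return p;
    }

    /** Little-endian counter in base maxIdx+1; returns true on overflow. */
    private static boolean increment(int[] indices, int maxIdx) {
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < maxIdx) {
                indices[i]++;
                return false;
            }
            indices[i] = 0;                  // wrap around and carry
        }
        return true;
    }
}
```

    For k parameters, each level of this sketch evaluates partitions^k combinations, so the total number of evaluations is depth * partitions^k; for example, depth=4, partitions=4 and two parameters give 64 evaluations. The cost therefore grows exponentially in the number of parameters, while increasing depth only adds cost linearly.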
    • Field Summary

      Fields 
      Modifier and Type Field Description
      private int depth
      The depth of the recursion.
      private int partitions
      The number of partitions of each parameter interval.
      private ClassificationTester<S,T> tester
      The tester used for measuring the performance of each parameter combination.
    • Method Summary

      Modifier and Type Method Description
      private ParameterSet adjustParameterSet(ParameterSet set, int[] indices, double[] lowerBounds, double[] upperBounds)
      Determines, for every parameter of the set, a new value by partitioning [lowerBound,upperBound] into this.partitions sub-intervals and taking the center point of the sub-interval whose index is given by the corresponding entry of indices.
      private boolean increment(int[] indices, int maxIdx)
      Increments the given array of indices, treating it as a counter in base maxIdx+1 whose first entry is the least significant digit.
      ParameterSet learnParameters(TrainingSet<S,T> trainingSet)
      Learns the best parameters of the given trainer for the training set.
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Field Detail

      • depth

        private int depth
        The depth of the recursion.
      • partitions

        private int partitions
        The number of partitions of each parameter interval.
    • Constructor Detail

      • GridSearchParameterLearner

        public GridSearchParameterLearner(Trainer<S,T> trainer,
                                          ClassificationTester<S,T> tester,
                                          int depth,
                                          int partitions)
        Creates a new grid-search parameter learner with the given arguments.
        Parameters:
        trainer - the trainer whose parameters are to be learned.
        tester - the classification tester used to measure the performance of each parameter combination.
        depth - the depth of the recursion.
        partitions - the number of partitions of each parameter interval.
    • Method Detail

      • adjustParameterSet

        private ParameterSet adjustParameterSet(ParameterSet set,
                                                int[] indices,
                                                double[] lowerBounds,
                                                double[] upperBounds)
        Determines, for every parameter of the set, a new value by partitioning [lowerBound,upperBound] into this.partitions sub-intervals and taking the center point of the sub-interval whose index is given by the corresponding entry of indices.
        Parameters:
        set - a parameter set
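        indices - for each parameter, the index of the selected partition.
        lowerBounds - the lower bound of each parameter's interval.
        upperBounds - the upper bound of each parameter's interval.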
        Returns:
        a new parameter set
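        A minimal sketch of the computation described for adjustParameterSet, assuming equal-width sub-intervals and 0-based partition indices. It works on plain double arrays; the class PartitionCenters and the method centers are hypothetical stand-ins, since building an actual ParameterSet depends on library classes not shown here.

```java
/**
 * Hypothetical stand-in for adjustParameterSet that works on plain doubles.
 * Assumes equal-width sub-intervals and 0-based partition indices.
 */
final class PartitionCenters {

    /**
     * For every parameter i, splits [lowerBounds[i], upperBounds[i]] into
     * `partitions` sub-intervals and returns the center of the sub-interval
     * with index indices[i].
     */
    static double[] centers(int[] indices, double[] lowerBounds,
                            double[] upperBounds, int partitions) {
        if (indices.length != lowerBounds.length || indices.length != upperBounds.length) {
            throw new IllegalArgumentException("array lengths differ");
        }
        if (partitions < 1) {
            throw new IllegalArgumentException("partitions must be positive");
        }
        double[] values = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            double width = (upperBounds[i] - lowerBounds[i]) / partitions;
            // Left border of the selected sub-interval plus half its width.
            values[i] = lowerBounds[i] + (indices[i] + 0.5) * width;
        }
        return values;
    }
}
```

        For example, centers(new int[]{0, 3}, new double[]{0, 10}, new double[]{1, 20}, 4) yields [0.125, 18.75]: [0,1] is split into sub-intervals of width 0.25, the first of which is centered at 0.125, and [10,20] is split into sub-intervals of width 2.5, so sub-interval 3 is [17.5,20] with center 18.75.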
      • increment

        private boolean increment(int[] indices,
                                  int maxIdx)
        Increments the given array of indices in place, treating it as a counter in base maxIdx+1 whose first entry is the least significant digit. For example (for maxIdx=5), [0,0,0,0] becomes [1,0,0,0], [4,2,1,4] becomes [5,2,1,4], [5,2,1,4] becomes [0,3,1,4], [5,5,1,4] becomes [0,0,2,4], and so on. It returns true iff an overflow occurs, e.g. when [5,5,5,5] is incremented.
        Parameters:
        indices - an array of ints.
        maxIdx - the maximum value an entry may take.
        Returns:
        true iff an overflow occurs
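        The examples above match a little-endian counter in base maxIdx+1: the first entry is incremented, and whenever an entry would exceed maxIdx it wraps to 0 and carries into the next entry. A minimal sketch follows; the class name IndexCounter is hypothetical, and resetting all entries to 0 on overflow is an assumption, since the documentation only states that true is returned.

```java
/** Hypothetical sketch of the documented increment behavior. */
final class IndexCounter {

    /**
     * Increments `indices` in place as a little-endian counter in base
     * maxIdx+1. Returns true iff every entry wrapped around (overflow);
     * in that case the array is left as all zeros (assumption).
     */
    static boolean increment(int[] indices, int maxIdx) {
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < maxIdx) {
                indices[i]++;      // no carry needed: done
                return false;
            }
            indices[i] = 0;        // this digit overflows: carry into the next
        }
        return true;               // carry fell off the last digit
    }
}
```

        Starting from an all-zero array of length n, a do/while loop that processes the current array and then calls increment until it returns true visits each of the (maxIdx+1)^n index combinations exactly once, which is how every combination of partitions can be enumerated.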