Package net.sf.tweety.machinelearning
Class GridSearchParameterLearner<S extends Observation,T extends Category>
java.lang.Object
  net.sf.tweety.machinelearning.ParameterTrainer<S,T>
    net.sf.tweety.machinelearning.GridSearchParameterLearner<S,T>
Type Parameters:
S - the type of observations.
T - the type of categories.
All Implemented Interfaces:
Trainer<S,T>
public class GridSearchParameterLearner<S extends Observation,T extends Category> extends ParameterTrainer<S,T>
A grid-search approach for learning parameters. For each parameter, with I=[l,u] being the boundaries for the parameter value of a given trainer, I is divided into as many sub-intervals as specified by partitions. For each partition of each parameter, the border points are chosen and a new classifier is learned with the given parameter combination. Of all combinations, the one for which the classifier performs best is chosen. If depth > 1, the process is iterated: after selecting the best interval combination of the parameters, these intervals are again divided and the process is repeated depth times.
Author:
Matthias Thimm
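The iterated procedure described above can be sketched in self-contained Java. The sketch is illustrative and rests on assumptions that are not part of this API: the trainer and the ClassificationTester are replaced by a hypothetical scoring function over a double[] of parameter values (higher is better), and, following the description of adjustParameterSet rather than the border points mentioned above, each sub-interval is represented by its center point, with the winning sub-interval becoming the search interval of the next round.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;

public class GridSearchSketch {

    /**
     * Iterated grid search (illustrative sketch, not Tweety's code).
     * score is a hypothetical stand-in for "train a classifier with these
     * parameter values and test it"; higher is better.
     */
    static double[] search(ToDoubleFunction<double[]> score,
                           double[] lower, double[] upper,
                           int depth, int partitions) {
        int k = lower.length;
        double[] lo = lower.clone();
        double[] hi = upper.clone();
        double[] best = null;
        for (int level = 0; level < depth; level++) {
            int[] idx = new int[k];          // one sub-interval index per parameter
            int[] bestIdx = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            do {
                // assumption: evaluate the center point of each selected sub-interval
                double[] candidate = new double[k];
                for (int i = 0; i < k; i++) {
                    double width = (hi[i] - lo[i]) / partitions;
                    candidate[i] = lo[i] + (idx[i] + 0.5) * width;
                }
                double s = score.applyAsDouble(candidate);
                if (bestIdx == null || s > bestScore) {
                    bestScore = s;
                    bestIdx = idx.clone();
                    best = candidate;
                }
            } while (!increment(idx, partitions - 1));
            // narrow every parameter interval to the winning sub-interval
            for (int i = 0; i < k; i++) {
                double width = (hi[i] - lo[i]) / partitions;
                lo[i] = lo[i] + bestIdx[i] * width;
                hi[i] = lo[i] + width;
            }
        }
        return best;
    }

    /** Mixed-radix counter, first position least significant; true on overflow. */
    static boolean increment(int[] indices, int maxIdx) {
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < maxIdx) {
                indices[i]++;
                return false;
            }
            indices[i] = 0;
        }
        return true;
    }

    public static void main(String[] args) {
        // toy objective with its maximum at (0.3372, 2.714)
        ToDoubleFunction<double[]> score =
            p -> -Math.pow(p[0] - 0.3372, 2) - Math.pow(p[1] - 2.714, 2);
        double[] best = search(score, new double[]{0, 0}, new double[]{1, 10}, 3, 10);
        System.out.println(Arrays.toString(best));
    }
}
```

Replacing the toy score with a function that trains and tests a classifier gives the overall shape of learnParameters. Under this sketch's center-point assumption, each round costs partitions^k evaluations for k parameters, and each round shrinks every interval by a factor of partitions.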
Field Summary
Modifier and Type | Field | Description
private int | depth | The depth of the recursion.
private int | partitions | The number of partitions of each parameter interval.
private ClassificationTester<S,T> | tester | The tester used for measuring the performance of each parameter combination.
Constructor Summary
Constructor | Description
GridSearchParameterLearner(Trainer<S,T> trainer, ClassificationTester<S,T> tester, int depth, int partitions) | Creates a new grid-search parameter learner with the given arguments.
Method Summary
Modifier and Type | Method | Description
private ParameterSet | adjustParameterSet(ParameterSet set, int[] indices, double[] lowerBounds, double[] upperBounds) | Determines a new value for every parameter of the set by partitioning [lowerBound,upperBound] into this.partitions sub-intervals and taking the center point of the partition whose number is given by indices.
private boolean | increment(int[] indices, int maxIdx) | Increments the given array of indices.
ParameterSet | learnParameters(TrainingSet<S,T> trainingSet) | Learns the best parameters of the given trainer for the training set.
Methods inherited from class net.sf.tweety.machinelearning.ParameterTrainer
getParameterSet, getTrainer, setParameterSet, train, train
Field Detail
depth
private int depth
The depth of the recursion.
partitions
private int partitions
The number of partitions of each parameter interval.
tester
private ClassificationTester<S extends Observation,T extends Category> tester
The tester used for measuring the performance of each parameter combination.
Constructor Detail
GridSearchParameterLearner
public GridSearchParameterLearner(Trainer<S,T> trainer, ClassificationTester<S,T> tester, int depth, int partitions)
Creates a new grid-search parameter learner with the given arguments.
Parameters:
trainer - the trainer whose parameters are to be learned.
tester - some classification tester for measuring performance.
depth - the depth of the recursion.
partitions - the number of partitions of each parameter interval.
Method Detail
learnParameters
public ParameterSet learnParameters(TrainingSet<S,T> trainingSet)
Description copied from class: ParameterTrainer
Learns the best parameters of the given trainer for the training set.
Specified by:
learnParameters in class ParameterTrainer<S extends Observation,T extends Category>
Parameters:
trainingSet - a training set.
Returns:
the best parameters for the training set.
adjustParameterSet
private ParameterSet adjustParameterSet(ParameterSet set, int[] indices, double[] lowerBounds, double[] upperBounds)
Determines a new value for every parameter of the set by partitioning [lowerBound,upperBound] into this.partitions sub-intervals and taking the center point of the partition whose number (idx) is given by indices.
Parameters:
set - a parameter set.
indices - the partition indices.
lowerBounds - the lower bounds.
upperBounds - the upper bounds.
Returns:
a new parameter set.
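The center-point rule described for adjustParameterSet can be illustrated on plain arrays. This sketch is hypothetical: it represents a ParameterSet as a double[] with one value per parameter, names the helper centerPoints for illustration only, and assumes that indices[i] selects the sub-interval of parameter i, counted from 0.

```java
import java.util.Arrays;

public class CenterPointSketch {

    /**
     * For every parameter i, splits [lowerBounds[i], upperBounds[i]] into
     * `partitions` equal sub-intervals and returns the center of sub-interval indices[i].
     * Hypothetical helper; not part of the Tweety API.
     */
    static double[] centerPoints(int[] indices, double[] lowerBounds,
                                 double[] upperBounds, int partitions) {
        double[] values = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            double width = (upperBounds[i] - lowerBounds[i]) / partitions;
            values[i] = lowerBounds[i] + (indices[i] + 0.5) * width;
        }
        return values;
    }

    public static void main(String[] args) {
        // two parameters, four partitions each: [0,1] has width 0.25, [10,20] has width 2.5
        double[] v = centerPoints(new int[]{0, 3}, new double[]{0, 10}, new double[]{1, 20}, 4);
        System.out.println(Arrays.toString(v)); // [0.125, 18.75]
    }
}
```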
increment
private boolean increment(int[] indices, int maxIdx)
Increments the given array of indices in place, treating the first position as the least significant one, e.g. (for maxIdx=5) [0,0,0,0] becomes [1,0,0,0], [4,2,1,4] becomes [5,2,1,4], [5,2,1,4] becomes [0,3,1,4], [5,5,1,4] becomes [0,0,2,4], etc. It returns true iff an overflow occurs, e.g. if [5,5,5,5] is to be incremented.
Parameters:
indices - an array of ints.
maxIdx - the max index.
Returns:
true iff an overflow occurs.
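The counter behaviour documented for increment can be reproduced in a few lines of Java. This is a sketch written from the examples above, not Tweety's source: the array is treated as a mixed-radix counter whose first position is the least significant digit and whose digits run from 0 to maxIdx. Resetting every digit to 0 on overflow is an assumption; the documentation only states that true is returned.

```java
import java.util.Arrays;

public class IncrementSketch {

    /** Increments indices in place; returns true iff every position was at maxIdx (overflow). */
    static boolean increment(int[] indices, int maxIdx) {
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < maxIdx) {
                indices[i]++;        // no carry needed
                return false;
            }
            indices[i] = 0;          // digit wraps around, carry into the next position
        }
        return true;                 // carried past the last position (assumed reset to zeros)
    }

    public static void main(String[] args) {
        int[] a = {5, 2, 1, 4};
        boolean overflow = increment(a, 5);
        System.out.println(Arrays.toString(a) + " " + overflow); // [0, 3, 1, 4] false
    }
}
```

Calling increment repeatedly from an all-zero array until it returns true visits every index combination exactly once, which is what a grid search over all parameter combinations needs.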