package org.tweetyproject.logics.rpcl.reasoner;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tweetyproject.commons.Interpretation;
import org.tweetyproject.commons.ModelProvider;
import org.tweetyproject.commons.QuantitativeReasoner;
import org.tweetyproject.logics.commons.syntax.Constant;
import org.tweetyproject.logics.commons.syntax.Predicate;
import org.tweetyproject.logics.fol.semantics.HerbrandBase;
import org.tweetyproject.logics.fol.semantics.HerbrandInterpretation;
import org.tweetyproject.logics.fol.syntax.FolBeliefSet;
import org.tweetyproject.logics.fol.syntax.FolFormula;
import org.tweetyproject.logics.fol.syntax.FolSignature;
import org.tweetyproject.logics.rpcl.semantics.CondensedProbabilityDistribution;
import org.tweetyproject.logics.rpcl.semantics.ReferenceWorld;
import org.tweetyproject.logics.rpcl.semantics.RpclProbabilityDistribution;
import org.tweetyproject.logics.rpcl.semantics.RpclSemantics;
import org.tweetyproject.logics.rpcl.syntax.RelationalProbabilisticConditional;
import org.tweetyproject.logics.rpcl.syntax.RpclBeliefSet;
import org.tweetyproject.math.GeneralMathException;
import org.tweetyproject.math.equation.Equation;
import org.tweetyproject.math.opt.ProblemInconsistentException;
import org.tweetyproject.math.opt.problem.OptimizationProblem;
import org.tweetyproject.math.opt.solver.Solver;
import org.tweetyproject.math.probability.Probability;
import org.tweetyproject.math.term.FloatConstant;
import org.tweetyproject.math.term.FloatVariable;
import org.tweetyproject.math.term.IntegerConstant;
import org.tweetyproject.math.term.Logarithm;
import org.tweetyproject.math.term.Product;
import org.tweetyproject.math.term.Term;
import org.tweetyproject.math.term.Variable;

/* loaded from: input_file:org/tweetyproject/logics/rpcl/reasoner/RpclMeReasoner.class */
/**
 * Reasoner for relational probabilistic conditional logic (RPCL) based on the principle of
 * maximum entropy.
 *
 * <p>Given an {@link RpclBeliefSet}, this reasoner builds an optimization problem with:
 * <ul>
 *   <li>variables: the probability of each possible world;</li>
 *   <li>constraints: the normalization condition plus one satisfaction statement per
 *       conditional, as supplied by the configured {@link RpclSemantics};</li>
 *   <li>target function: the entropy, which is maximized.</li>
 * </ul>
 * The solution is returned as the ME-distribution and is also used to answer probability
 * queries.
 *
 * <p>Two inference types are supported:
 * <ul>
 *   <li>{@link #STANDARD_INFERENCE}: one variable per Herbrand interpretation of the
 *       signature;</li>
 *   <li>{@link #LIFTED_INFERENCE}: one variable per reference world, built from the
 *       equivalence classes of constants. This is only applicable when no predicate has
 *       arity greater than one.</li>
 * </ul>
 */
public class RpclMeReasoner implements QuantitativeReasoner<RpclBeliefSet, FolFormula>, ModelProvider<RelationalProbabilisticConditional, RpclBeliefSet, RpclProbabilityDistribution<?>> {

    /** Logger for this class. */
    private static final Logger log = LoggerFactory.getLogger(RpclMeReasoner.class);

    /** Inference type: optimize over all Herbrand interpretations of the signature. */
    public static final int STANDARD_INFERENCE = 1;

    /** Inference type: optimize over reference worlds (only for at most unary predicates). */
    public static final int LIFTED_INFERENCE = 2;

    /** The semantics used to turn each conditional into an optimization constraint. */
    private final RpclSemantics semantics;

    /** Either {@link #STANDARD_INFERENCE} or {@link #LIFTED_INFERENCE}. */
    private final int inferenceType;

    /**
     * Creates a new ME-reasoner.
     *
     * @param semantics the semantics used to interpret the conditionals
     * @param inferenceType one of {@link #STANDARD_INFERENCE} or {@link #LIFTED_INFERENCE}
     * @throws IllegalArgumentException if {@code inferenceType} is neither of the two constants
     */
    public RpclMeReasoner(RpclSemantics semantics, int inferenceType) {
        if (inferenceType != STANDARD_INFERENCE && inferenceType != LIFTED_INFERENCE) {
            log.error("The inference type must be either 'standard' or 'lifted'.");
            throw new IllegalArgumentException("The inference type must be either 'standard' or 'lifted'.");
        }
        this.semantics = semantics;
        this.inferenceType = inferenceType;
    }

    /**
     * Creates a new ME-reasoner that uses {@link #STANDARD_INFERENCE}.
     *
     * @param semantics the semantics used to interpret the conditionals
     */
    public RpclMeReasoner(RpclSemantics semantics) {
        this(semantics, STANDARD_INFERENCE);
    }

    /**
     * Returns the inference type of this reasoner.
     *
     * @return {@link #STANDARD_INFERENCE} or {@link #LIFTED_INFERENCE}
     */
    public int getInferenceType() {
        return this.inferenceType;
    }

    /**
     * Computes the probability of {@code formula} in the ME-distribution of {@code kb}, where
     * the distribution is taken over the given signature.
     *
     * @param kb the belief set
     * @param formula the formula whose probability is queried
     * @param signature a super-signature of the belief set's minimal signature
     * @return the probability of {@code formula} in the ME-distribution
     * @throws IllegalArgumentException if the signature or inference type is not applicable
     * @throws ProblemInconsistentException if the optimization problem cannot be solved
     */
    public Double query(RpclBeliefSet kb, FolFormula formula, FolSignature signature) {
        return getModel(kb, signature).probability(formula).getValue();
    }

    /**
     * Computes the probability of {@code formula} in the ME-distribution of {@code kb}, taken
     * over the belief set's minimal signature.
     *
     * @param kb the belief set
     * @param formula the formula whose probability is queried
     * @return the probability of {@code formula} in the ME-distribution
     */
    @Override
    public Double query(RpclBeliefSet kb, FolFormula formula) {
        return query(kb, formula, (FolSignature) kb.getMinimalSignature());
    }

    /**
     * Returns a mutable set containing only the ME-distribution of {@code kb}.
     *
     * @param kb the belief set
     * @return a singleton collection holding {@link #getModel(RpclBeliefSet)}
     */
    @Override
    public Collection<RpclProbabilityDistribution<?>> getModels(RpclBeliefSet kb) {
        Set<RpclProbabilityDistribution<?>> models = new HashSet<>();
        models.add(getModel(kb));
        return models;
    }

    /**
     * Computes the ME-distribution of {@code kb} over its minimal signature.
     *
     * @param kb the belief set
     * @return the ME-distribution
     */
    @Override
    public RpclProbabilityDistribution<?> getModel(RpclBeliefSet kb) {
        return getModel(kb, (FolSignature) kb.getMinimalSignature());
    }

    /**
     * Computes the ME-distribution of {@code kb} over the given signature.
     *
     * <p>If the belief set is empty, the uniform distribution is returned without solving an
     * optimization problem.
     *
     * @param kb the belief set
     * @param signature a super-signature of the belief set's minimal signature
     * @return the ME-distribution. With lifted inference this is a
     *     {@link CondensedProbabilityDistribution}.
     * @throws IllegalArgumentException if {@code signature} does not contain the belief set's
     *     signature, or if lifted inference is used and {@code signature} contains a predicate
     *     of arity greater than one
     * @throws ProblemInconsistentException if the solver fails on the optimization problem
     */
    public RpclProbabilityDistribution<?> getModel(RpclBeliefSet kb, FolSignature signature) {
        if (!kb.getMinimalSignature().isSubSignature(signature)) {
            log.error("Signature must be super-signature of the belief set's signature.");
            throw new IllegalArgumentException("Signature must be super-signature of the belief set's signature.");
        }
        boolean lifted = this.inferenceType == LIFTED_INFERENCE;
        if (lifted) {
            // Reference worlds are enumerated over the predicates of the full signature, so every
            // one of them (not only those of the belief set) must be at most unary.
            for (Predicate predicate : signature.getPredicates()) {
                if (predicate.getArity() > 1) {
                    log.error("Lifted inference only applicable for signatures containing only unary predicates.");
                    throw new IllegalArgumentException("Lifted inference only applicable for signatures containing only unary predicates.");
                }
            }
        }
        log.info("Computing ME-distribution using \"{}\" and {} inference for the knowledge base {}.",
                this.semantics, lifted ? "lifted" : "standard", kb);
        log.info("Constructing optimization problem for finding the ME-distribution.");
        if (lifted) {
            return computeLiftedDistribution(kb, signature);
        }
        return computeStandardDistribution(kb, signature);
    }

    /**
     * Computes the ME-distribution over all Herbrand interpretations of {@code signature}.
     * Each interpretation gets one variable in [0, 1].
     */
    private RpclProbabilityDistribution<?> computeStandardDistribution(RpclBeliefSet kb, FolSignature signature) {
        // Short-cut before enumerating the (exponentially many) Herbrand interpretations.
        if (kb.size() == 0) {
            return RpclProbabilityDistribution.getUniformDistribution(this.semantics, signature);
        }
        Set<HerbrandInterpretation> worlds = new HerbrandBase(signature).getAllHerbrandInterpretations();
        Map<Interpretation<FolBeliefSet, FolFormula>, FloatVariable> worlds2vars = new HashMap<>();
        OptimizationProblem problem = new OptimizationProblem(OptimizationProblem.MAXIMIZE);

        // Normalization: the probabilities of all worlds sum up to one.
        Term normalization = null;
        int index = 0;
        for (HerbrandInterpretation world : worlds) {
            FloatVariable variable = new FloatVariable("X" + index++, 0.0, 1.0);
            worlds2vars.put(world, variable);
            if (normalization == null) {
                normalization = variable;
            } else {
                normalization = normalization.add(variable);
            }
        }
        problem.add(new Equation(normalization, new FloatConstant(1.0f)));
        for (RelationalProbabilisticConditional conditional : kb) {
            problem.add(this.semantics.getSatisfactionStatement(conditional, signature, worlds2vars));
        }

        // Target function: the entropy -sum_w p(w) * log p(w).
        Term entropy = null;
        for (FloatVariable variable : worlds2vars.values()) {
            Product summand = new IntegerConstant(-1).mult(variable.mult(new Logarithm(variable)));
            if (entropy == null) {
                entropy = summand;
            } else {
                entropy = entropy.add(summand);
            }
        }
        problem.setTargetFunction(entropy);

        try {
            Map<Variable, Term> solution = Solver.getDefaultGeneralSolver().solve(problem);
            RpclProbabilityDistribution<HerbrandInterpretation> distribution =
                    new RpclProbabilityDistribution<>(this.semantics, signature);
            for (HerbrandInterpretation world : worlds) {
                distribution.put(world, toProbability(solution, worlds2vars.get(world)));
            }
            return distribution;
        } catch (GeneralMathException e) {
            log.error("The knowledge base {} is inconsistent.", kb, e);
            throw new ProblemInconsistentException();
        }
    }

    /**
     * Computes the ME-distribution over the reference worlds induced by the equivalence classes
     * of constants. Each reference world's variable is weighted with its {@code spanNumber()}.
     * NOTE(review): the span number is presumably the number of Herbrand interpretations the
     * reference world stands for; confirm in ReferenceWorld.
     */
    private RpclProbabilityDistribution<?> computeLiftedDistribution(RpclBeliefSet kb, FolSignature signature) {
        Set<Set<Constant>> equivalenceClasses = kb.getEquivalenceClasses(signature);
        // Short-cut before enumerating reference worlds.
        if (kb.size() == 0) {
            return CondensedProbabilityDistribution.getUniformDistribution(this.semantics, signature, equivalenceClasses);
        }
        Set<ReferenceWorld> worlds = ReferenceWorld.enumerateReferenceWorlds(signature.getPredicates(), equivalenceClasses);
        Map<Interpretation<FolBeliefSet, FolFormula>, FloatVariable> worlds2vars = new HashMap<>();
        OptimizationProblem problem = new OptimizationProblem(OptimizationProblem.MAXIMIZE);

        // Normalization: the span-weighted probabilities of all reference worlds sum up to one.
        Term normalization = null;
        int index = 0;
        for (ReferenceWorld world : worlds) {
            FloatVariable variable = new FloatVariable("X" + index++, 0.0, 1.0);
            worlds2vars.put(world, variable);
            Product weighted = new FloatConstant(world.spanNumber().intValue()).mult(variable);
            if (normalization == null) {
                normalization = weighted;
            } else {
                normalization = normalization.add(weighted);
            }
        }
        problem.add(new Equation(normalization, new FloatConstant(1.0f)));
        for (RelationalProbabilisticConditional conditional : kb) {
            problem.add(this.semantics.getSatisfactionStatement(conditional, signature, worlds2vars));
        }

        // Target function: -sum_w spanNumber(w) * p(w) * log p(w).
        Term entropy = null;
        for (Map.Entry<Interpretation<FolBeliefSet, FolFormula>, FloatVariable> entry : worlds2vars.entrySet()) {
            // Only ReferenceWorld keys are put into worlds2vars above.
            int span = ((ReferenceWorld) entry.getKey()).spanNumber().intValue();
            FloatVariable variable = entry.getValue();
            Product summand = new IntegerConstant(-span).mult(variable.mult(new Logarithm(variable)));
            if (entropy == null) {
                entropy = summand;
            } else {
                entropy = entropy.add(summand);
            }
        }
        problem.setTargetFunction(entropy);

        try {
            Map<Variable, Term> solution = Solver.getDefaultGeneralSolver().solve(problem);
            CondensedProbabilityDistribution distribution = new CondensedProbabilityDistribution(this.semantics, signature);
            for (ReferenceWorld world : worlds) {
                distribution.put(world, toProbability(solution, worlds2vars.get(world)));
            }
            return distribution;
        } catch (GeneralMathException e) {
            log.error("The knowledge base {} is inconsistent.", kb, e);
            throw new ProblemInconsistentException();
        }
    }

    /** Reads the solver's value for {@code variable} and wraps it as a {@link Probability}. */
    private static Probability toProbability(Map<Variable, Term> solution, FloatVariable variable) {
        return new Probability(Double.valueOf(solution.get(variable).value().doubleValue()));
    }
}
