java - 如何有效地更新 EnumeratedDistribution 实例中的概率?

标签 java android distribution q-learning

问题总结

有没有办法更新 EnumeratedIntegerDistribution 类的现有实例中的概率?不创建一个全新的实例?

背景

我正在尝试使用 Android 手机实现简化的 Q 学习风格演示。我需要在学习过程中的每个循环中更新每个项目的概率。目前,我无法从我的 enumeratedIntegerDistribution 实例中找到任何可让我重置|更新|修改这些概率的方法。因此,我能看到的唯一方法是在每个循环中创建一个新的 EnumeratedIntegerDistribution 实例。请记住,每个循环只有 20 毫秒长,据我了解,与创建一个实例并更新现有实例中的值相比,这将是非常低效的内存。是否没有标准的集合式方法来更新这些概率?如果没有,是否有推荐的解决方法(即使用不同的类、创建我自己的类、覆盖某些内容以使其可访问等?)

跟进将是这个问题是否是一个没有实际意义的努力。通过尝试在每个循环中避免这个新实例,编译后的代码实际上会提高/降低效率吗? (我的知识不够了解编译器将如何处理此类事情)。

代码

下面是一个最小的例子:

package com.example.mypackage.learning;  
  
import android.app.Activity;  
import android.os.Bundle;  
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;  
  
  
public class Qlearning extends Activity {  
  
    private int selectedAction;  
    private int[] actions = {0, 1, 2};  
    private double[] weights = {1.0, 1.0, 1.0};  
    private double[] qValues = {1.0, 1.0, 1.0};  
    private double qValuesSum;  
    EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);  
    private final double alpha = 0.001;  
    int action;  
    double reward;  
  
    @Override  
    protected void onCreate(Bundle savedInstanceState) {  
        super.onCreate(savedInstanceState);  
        while(true){  
            action = determineAction();  
            reward = determineReward();  
            learn(action, reward);  
        }  
    }  
      
    public void learn(int action, double reward) {  
        qValues[selectedAction] = (alpha * reward) + ((1.0 - alpha) * qValues[selectedAction]);  
        qValuesSum = 0;  
        for (int i = 0; i < qValues.length; i++){  
            qValuesSum += Math.exp(qValues[i]);  
        }  
        weights[selectedAction] = Math.exp(qValues[selectedAction]) / qValuesSum;  
        // *** This seems inefficient ***  
        EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);  
    }  
}

请不要关注缺少 determineAction()determineReward() 方法,因为这只是一个最小的示例。如果您想要一个工作示例,您可以很容易地在其中插入固定值(例如 1 和 1.5)。

此外,我很清楚无限 while 循环对于 GUI 来说会很麻烦,但同样,我只是想减少我必须在此处显示的代码以表达观点。

编辑:

作为对评论的回应,我在下面发布了我对类似类(class)的评论。请注意,我已经一年多没用过它了,可能会坏掉。仅供引用:

public class ActionDistribution{

    private double reward = 0;
    private double[] weights = {0.34, 0.34, 0.34};
    private double[] qValues = {0.1, 0.1, 0.1};
    private double learningRate = 0.1;
    private double temperature = 1.0;
    private int selectedAction;

    public ActionDistribution(){}

    public ActionDistribution(double[] weights, double[] qValues, double learningRate, double temperature){
        this.weights = weights;
        this.qValues = qValues;
        this.learningRate = learningRate;
        this.temperature = temperature;
    }

    public int actionSelect(){

        double sumOfWeights = 0;
        for (double weight: weights){
            sumOfWeights = sumOfWeights + weight;
        }
        double randNum = Math.random() * sumOfWeights;
        double selector = 0;
        int iterator = -1;

        while (selector < randNum){
            try {
                iterator++;
                selector = selector + weights[iterator];
            }catch (ArrayIndexOutOfBoundsException e){
                Log.e("abcvlib", "weight index bound exceeded. randNum was greater than the sum of all weights. This can happen if the sum of all weights is less than 1.");
            }
        }
        // Assigning this as a read-only value to pass between threads.
        this.selectedAction = iterator;

        // represents the action to be selected
        return iterator;
    }

    public double[] getWeights(){
        return weights;
    }

    public double[] getqValues(){
        return qValues;
    }

    public double getQValue(int action){
        return qValues[action];
    }

    public double getTemperature(){
        return temperature;
    }


    public int getSelectedAction() {
        return selectedAction;
    }

    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    public void setQValue(int action, double qValue) {
        this.qValues[action] = qValue;
    }

    public void updateValues(double reward, int action){

        double qValuePrev = getQValue(action);

        // update qValues due to current reward
        setQValue(action,(learningRate * reward) + ((1.0 - learningRate) * qValuePrev));
        // update weights from new qValues

        double qValuesSum = 0;
        for (double qValue : getqValues()) {
            qValuesSum += Math.exp(temperature * qValue);
        }

        // update weights
        for (int i = 0; i < getWeights().length; i++){
            getWeights()[i] = Math.exp(temperature * getqValues()[i]) / qValuesSum;
        }
    }

    public double getReward() {
        return reward;
    }

    public void setReward(double reward) {
        this.reward = reward;
    }
}

最佳答案

很遗憾,无法更新现有的 EnumeratedIntegerDistribution。我过去遇到过类似的问题,每次需要更新机会时我都重新创建了实例。

我不会太担心内存分配,因为这些都是短期对象。这些是您不应该担心的微优化。

在我的项目中,我确实使用接口(interface)实现了一种更简洁的方式来创建这些 EnumeratedDistribution 的实例。类。

这不是直接的答案,但可能会引导您朝着正确的方向前进。

public class DistributedProbabilityGeneratorBuilder<T extends DistributedProbabilityGeneratorBuilder.ProbableItem> {

    private static final DistributedProbabilityGenerator EMPTY = () -> {
        throw new UnsupportedOperationException("Not supported");
    };

    private final Map<Integer, T> distribution = new HashMap<>();

    private DistributedProbabilityGeneratorBuilder() {
    }

    public static <T extends ProbableItem> DistributedProbabilityGeneratorBuilder<T> newBuilder() {
        return new DistributedProbabilityGeneratorBuilder<>();
    }

    public DistributedProbabilityGenerator build() {
        return build(ProbableItem::getChances);
    }

    /**
     * Returns a new instance of probability generator at every call.
     * @param chanceChangeFunction - Function to modify existing chances
     */
    public DistributedProbabilityGenerator build(Function<T, Double> chanceChangeFunction) {
        if (distribution.isEmpty()) {
            return EMPTY;
        } else {
            return new NonEmptyProbabilityGenerator(createPairList(chanceChangeFunction));
        }
    }

    private List<Pair<Integer, Double>> createPairList(Function<T, Double> chanceChangeFunction) {
        return distribution.entrySet().stream()
                .map(entry -> Pair.create(entry.getKey(), chanceChangeFunction.apply(entry.getValue())))
                .collect(Collectors.toList());
    }

    public DistributedProbabilityGeneratorBuilder<T> add(int id, T item) {
        if (distribution.containsKey(id)) {
            throw new IllegalArgumentException("Id " + id + " already present.");
        }

        this.distribution.put(id, item);
        return this;
    }

    public interface ProbableItem {

        double getChances();
    }

    public interface DistributedProbabilityGenerator {

        int generateId();
    }

    public static class NonEmptyProbabilityGenerator implements DistributedProbabilityGenerator {

        private final EnumeratedDistribution<Integer> enumeratedDistribution;

        NonEmptyProbabilityGenerator(List<Pair<Integer, Double>> pairs) {
            this.enumeratedDistribution = new EnumeratedDistribution<>(pairs);
        }

        @Override
        public int generateId() {
            return enumeratedDistribution.sample();
        }
    }

    public static ProbableItem ofDouble(double chances) {
        return () -> chances;
    }
}

注意 - 我正在使用 EnumeratedDistribution<Integer> .您可以轻松地将其更改为 EnumuratedIntegerDistribution .

我使用上述类的方式如下。

DistributedProbabilityGenerator distributedProbabilityGenerator = DistributedProbabilityGeneratorBuilder.newBuilder()
                .add(0, ofDouble(10))
                .add(1, ofDouble(45))
                .add(2, ofDouble(45))
                .build();

int generatedObjectId = distributedProbabilityGenerator.generateId();

同样,这不是对您问题的直接回答,而是更多地指导您如何以更好的方式使用这些类。

关于java - 如何有效地更新 EnumeratedDistribution 实例中的概率?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58796591/

相关文章:

java - android ksoap2 java.lang.IllegalArgumentException : size <= 0

android - 使用 android.tools.lint 执行 JetifyTransform 失败

c# - 这些库可以与我的 C# 应用程序一起分发吗?

python - 在 matplotlib 中将 hist2d 输出转换为轮廓

java - Maven Tycho 插件和 Eclipse Babel

java - 使用spring存储过程调用oracle存储过程

Android屏幕尺寸全屏

r - 如何在R中生成具有指定对数正态分布的随机数?

java - Swt FormToolkit焦点问题

java - 环境变量不支持UTF-8?