Pytorch(Python)中的itertools.count()函数

2023-05-16

在看深度强化学习DQN代码时,遇到这段代码,搞了好久都没看明白。
完整代码参考这个博客。

for t in count():
        #count()用法: itertools.count(start=0, step=1)
        #start:序列的开始(默认为0)
        #step:连续数字之间的差(默认为1)
        reward = 0  #设置初始化奖励为0
        m_reward = 0#求和奖励
        # 每m帧完成一次action
        action = select_action(state)#选择动作
        #每四步更新一次奖励
        for i in range(m):
            #与环境交互,选择一个动作之后,获得奖励,并判断是否时最终状态
            _, reward, done, _ = env.step(action.item()) 
            if not done:
                #如果不是终止状态,那么屏幕截屏到next_state_queue
                next_state_queue.append(get_screen())
            else:
                #否则的话,就终止程序
                break
            m_reward += reward#然后累加奖励

        if not done:
            #如果不是终止状态,那么就进入下一个状态,把下一个状态连接到一起,使用tuple,不会被修改
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:
            #如果是终止状态,则下一个状态就没有了,获取最终奖励
            next_state = None
            m_reward = 150
        m_reward = torch.tensor([m_reward], device=device)#把奖励转换成张量

        memory.push(state, action, next_state, m_reward)#把计算出来的四个元素集存储到replay buffer中

        state = next_state#把下一个状态转为当前状态
        optimize_model()#开始优化模型

这个for循环的使用方式说实话我是真的不明白。

for t in count():

能找到关于count()的信息是上面的import部分

from itertools import count

然后我找了好多博客,最后这个博客给我讲明白了。

itertools.count(start,step)函数的意思是创建一个从start开始每次的步长是step的无穷序列
当count()括号里为空时,表示从0开始,每次步长为1.

我们再回到实际的代码环境中。
这段代码出现在迭代训练阶段
第一个for循环时迭代次数
在这个训练开始时,我们会使用random_start()函数计算出done, state_queue, next_state_queue,即状态的状态(终止状态和非终止状态),当前状态序列和下一个状态序列。
然后首先就要判断当前状态时是否是终止状态,不是终止状态就继续我们说的这个for循环。
那么第二个这个for循环为什么时无限制循环的呢?

for t in count():

这个循环开始,首先就是初始化奖励和初始化累计奖励

reward = 0  #设置初始化奖励为0
m_reward = 0#求和奖励

然后使用动作选择函数选择算法需要执行的动作

action = select_action(state)#选择动作

下面就开始第三个循环了

for i in range(m):

m=4,因为每个状态有四张图像
这个循环的第一行代码是

_, reward, done, _ = env.step(action.item())

作用就是将上面选择的动作输入到环境中,然后环境会给出奖励和判断该奖励是否是终止状态。

            if not done:
                #如果不是终止状态,那么屏幕截屏到next_state_queue
                next_state_queue.append(get_screen())
            else:
                #否则的话,就终止程序
                break
            m_reward += reward#然后累加奖励

然后就开始判断该状态是否是终止状态,如果是终止状态就跳出该循环,不是的话就把当前屏幕截屏添加到next_state_queue序列中。
m=4,所以要执行四次。然后把这四次采集到的图像存储到序列中,需要提到的是,在这个for循环中,agent所使用的动作是一样的。
采集到四张图像之后,这个循环结束。
然后开始金鱼不判断状态是否结束了

        if not done:
            #如果不是终止状态,那么就进入下一个状态,把下一个状态连接到一起,使用tuple,不会被修改
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:
            #如果是终止状态,则下一个状态就没有了,获取最终奖励
            next_state = None
            m_reward = 150

如果没有结束,就把这个next_state_queue中的图像拼接cat起来,
如果是终止状态,那么提示没有下一个状态,给出奖励。
然后进行下一步

        m_reward = torch.tensor([m_reward], device=device)#把奖励转换成张量
        memory.push(state, action, next_state, m_reward)#把计算出来的四个元素集存储到replay buffer中
        state = next_state#把下一个状态转为当前状态
        optimize_model()#开始优化模型

这个动作执行结束后,把奖励转成张量,然后把transition四元数存储到replay buffer中。
然后更新当前状态。
开始优化模型。
在开始判断状态是否终止
并保存训练过程数据和更新网络模型参数
保存模型

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # 更新目标网络,复制DQN中的所有权重和偏置
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        if i_episode % 1000 ==0:
            torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))

当我把所有的循环看完之后,终于明白。这个无限循环的for循环是为了收集replay buffer中的transition。我们设置replay buffer的容量为100000,但是由于agent’与环境交互的不可知性导致我们知道到底要多少步才能完成。所以使用了这个循环。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch(Python)中的itertools.count()函数 的相关文章

随机推荐