• Najnowsze pytania
  • Bez odpowiedzi
  • Zadaj pytanie
  • Kategorie
  • Tagi
  • Zdobyte punkty
  • Ekipa ninja
  • IRC
  • FAQ
  • Regulamin
  • Książki warte uwagi

Gdzie jest błąd? - Program do uczenia maszynowego

Object Storage Arubacloud
0 głosów
314 wizyt
pytanie zadane 3 lipca 2021 w Python przez Aqua 4 Gaduła (3,220 p.)

Od pewnego czasu próbuję zaimplementować algorytm uczenie przez wzmacnianie "CrossEntropy" (wyjaśniony tutaj). Problem polega na tym, że agent nie uczy się niczego, a po kilku iteracjach sieć neuronowa zaczyna zwracać zawsze tą samą wartość. Debuguję ten program od kilku godzin, proszę pomóżcie.

Kod:

commom.py:

# common.py

from collections import namedtuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


EpisodeStep = namedtuple("EpisodeStep", field_names=("state", "action", "reward", "next_state", "done", "info"))
Episode = namedtuple("Episode", field_names=("total_reward", "history"))


class LinearNetwork(nn.Module):
    def __init__(self, input_shape, hidden_size, output_shape):
        super(LinearNetwork, self).__init__()
        self.nn = nn.Sequential(
            nn.Linear(in_features=input_shape, out_features=hidden_size), nn.ELU(),
            nn.Linear(in_features=hidden_size, out_features=output_shape))

    def forward(self, x):
        return self.nn(x)


class PolicyAgent(object):
    def __init__(self, net, env, device):
        super(PolicyAgent, self).__init__()
        self.net = net
        self.env = env
        self.device = device

    def __call__(self, observation):
        probs = F.softmax(self.net(
            torch.FloatTensor(observation[None, ...]).to(self.device))[0], dim=0).data.cpu().numpy()
        action = np.random.choice(probs.shape[0], p=probs)
        next_state, rewards, done, info = self.env.step(action)
        return EpisodeStep(state=observation, action=action, reward=rewards,
                           next_state=next_state, done=done, info=info)


class DataCollector(object):
    def __init__(self, agent, env):
        self.agent = agent
        self.env = env

    def iterate_steps(self):
        while True:
            obs = self.env.reset()
            done = False
            while not done:
                step = self.agent(obs)
                obs, done = step.next_state, step.done
                yield step

    def iterate_episodes(self):
        history, total_reward = [], 0.
        for step in self.iterate_steps():
            history.append(step)
            total_reward += step.reward
            if step.done:
                yield Episode(total_reward=total_reward, history=history)
                history.clear()
                total_reward = 0.


class EpisodeSelector(object):
    def __init__(self, data_collector, filter_candidates, percentile):
        super(EpisodeSelector, self).__init__()
        self.data_collector = data_collector
        self.filter_candidates = filter_candidates
        self.percentile = percentile

    def get_selected(self):
        generator = self.data_collector.iterate_episodes()
        while True:
            buffer = []
            for _ in range(self.filter_candidates):
                buffer.append(next(generator))
            rewards = np.array(list(map(lambda x: x.total_reward, buffer)))
            borderline = np.percentile(rewards, self.percentile)
            yield [episode for episode, reward in zip(buffer, rewards) if reward >= borderline], np.mean(rewards)


class DataIterator(object):
    def __init__(self, episode_selector):
        super(DataIterator, self).__init__()
        self.episode_selector = episode_selector

    def get_samples(self):
        for elite, reward in self.episode_selector.get_selected():
            states, actions = [], []
            for episode in elite:
                for step in episode.history:
                    states.append(step.state)
                    actions.append(step.action)
            X, y = np.array(states, dtype=np.float32, copy=False), np.array(actions, dtype=np.int32, copy=False)
            yield X, y, reward

main.py:

# main.py

import common
import gym

import torch
import torch.optim as optim
import torch.nn.functional as F


DEVICE = torch.device("cuda")
N_CANDIDATES = 256
PERCENTILE = 95


if __name__ == "__main__":
    env = gym.make("CartPole-v1")

    net = common.LinearNetwork(env.observation_space.shape[0], 128, env.action_space.n).to(DEVICE)
    opt = optim.Adam(net.parameters(), lr=1e-2)

    agent = common.PolicyAgent(net, env, DEVICE)
    collector = common.DataCollector(agent, env)
    selector = common.EpisodeSelector(collector, N_CANDIDATES, PERCENTILE)
    iterator = common.DataIterator(selector)

    for X_batch, y_batch, reward in iterator.get_samples():
        X_batch_v, y_batch_t = torch.FloatTensor(X_batch).to(DEVICE), torch.LongTensor(y_batch).to(DEVICE)

        opt.zero_grad()
        y_pred_v = net(X_batch_v)
        loss_v = F.cross_entropy(y_pred_v, y_batch_t)
        loss_v.backward()
        opt.step()
        print(y_batch_t, reward)

 

1 odpowiedź

0 głosów
odpowiedź 21 lipca 2021 przez Fut_techs Początkujący (420 p.)
Hej, wyślij co pokazuje terminal podczas errora by nieco bardziej nakreślić sytuację. Pozdrawiam.
komentarz 21 lipca 2021 przez Aqua 4 Gaduła (3,220 p.)
Chodzi o to, że wszystko wydaje się działać, ale nie ma żadnego postępu w trenowaniu.

Podobne pytania

–1 głos
0 odpowiedzi 195 wizyt
0 głosów
0 odpowiedzi 1,160 wizyt

92,662 zapytań

141,557 odpowiedzi

320,002 komentarzy

62,029 pasjonatów

Motyw:

Akcja Pajacyk

Pajacyk od wielu lat dożywia dzieci. Pomóż klikając w zielony brzuszek na stronie. Dziękujemy! ♡

Oto polecana książka warta uwagi.
Pełną listę książek znajdziesz tutaj.

Akademia Sekuraka

Niedawno wystartował dodruk tej świetnej, rozchwytywanej książki (około 940 stron). Mamy dla Was kod: pasja (wpiszcie go w koszyku), dzięki któremu otrzymujemy 10% zniżki - dziękujemy zaprzyjaźnionej ekipie Sekuraka za taki bonus dla Pasjonatów! Książka to pierwszy tom z serii o ITsec, który łagodnie wprowadzi w świat bezpieczeństwa IT każdą osobę - warto, polecamy!

...