• Najnowsze pytania
  • Bez odpowiedzi
  • Zadaj pytanie
  • Kategorie
  • Tagi
  • Zdobyte punkty
  • Ekipa ninja
  • IRC
  • FAQ
  • Regulamin
  • Książki warte uwagi

Gdzie jest błąd? - Program do uczenia maszynowego

Object Storage Arubacloud
0 głosów
311 wizyt
pytanie zadane 3 lipca 2021 w Python przez Aqua 4 Gaduła (3,220 p.)

Od pewnego czasu próbuję zaimplementować algorytm uczenie przez wzmacnianie "CrossEntropy" (wyjaśniony tutaj). Problem polega na tym, że agent nie uczy się niczego, a po kilku iteracjach sieć neuronowa zaczyna zwracać zawsze tą samą wartość. Debuguję ten program od kilku godzin, proszę pomóżcie.

Kod:

commom.py:

# common.py

from collections import namedtuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


EpisodeStep = namedtuple("EpisodeStep", field_names=("state", "action", "reward", "next_state", "done", "info"))
Episode = namedtuple("Episode", field_names=("total_reward", "history"))


class LinearNetwork(nn.Module):
    def __init__(self, input_shape, hidden_size, output_shape):
        super(LinearNetwork, self).__init__()
        self.nn = nn.Sequential(
            nn.Linear(in_features=input_shape, out_features=hidden_size), nn.ELU(),
            nn.Linear(in_features=hidden_size, out_features=output_shape))

    def forward(self, x):
        return self.nn(x)


class PolicyAgent(object):
    def __init__(self, net, env, device):
        super(PolicyAgent, self).__init__()
        self.net = net
        self.env = env
        self.device = device

    def __call__(self, observation):
        probs = F.softmax(self.net(
            torch.FloatTensor(observation[None, ...]).to(self.device))[0], dim=0).data.cpu().numpy()
        action = np.random.choice(probs.shape[0], p=probs)
        next_state, rewards, done, info = self.env.step(action)
        return EpisodeStep(state=observation, action=action, reward=rewards,
                           next_state=next_state, done=done, info=info)


class DataCollector(object):
    def __init__(self, agent, env):
        self.agent = agent
        self.env = env

    def iterate_steps(self):
        while True:
            obs = self.env.reset()
            done = False
            while not done:
                step = self.agent(obs)
                obs, done = step.next_state, step.done
                yield step

    def iterate_episodes(self):
        history, total_reward = [], 0.
        for step in self.iterate_steps():
            history.append(step)
            total_reward += step.reward
            if step.done:
                yield Episode(total_reward=total_reward, history=history)
                history.clear()
                total_reward = 0.


class EpisodeSelector(object):
    def __init__(self, data_collector, filter_candidates, percentile):
        super(EpisodeSelector, self).__init__()
        self.data_collector = data_collector
        self.filter_candidates = filter_candidates
        self.percentile = percentile

    def get_selected(self):
        generator = self.data_collector.iterate_episodes()
        while True:
            buffer = []
            for _ in range(self.filter_candidates):
                buffer.append(next(generator))
            rewards = np.array(list(map(lambda x: x.total_reward, buffer)))
            borderline = np.percentile(rewards, self.percentile)
            yield [episode for episode, reward in zip(buffer, rewards) if reward >= borderline], np.mean(rewards)


class DataIterator(object):
    def __init__(self, episode_selector):
        super(DataIterator, self).__init__()
        self.episode_selector = episode_selector

    def get_samples(self):
        for elite, reward in self.episode_selector.get_selected():
            states, actions = [], []
            for episode in elite:
                for step in episode.history:
                    states.append(step.state)
                    actions.append(step.action)
            X, y = np.array(states, dtype=np.float32, copy=False), np.array(actions, dtype=np.int32, copy=False)
            yield X, y, reward

main.py:

# main.py

import common
import gym

import torch
import torch.optim as optim
import torch.nn.functional as F


DEVICE = torch.device("cuda")
N_CANDIDATES = 256
PERCENTILE = 95


if __name__ == "__main__":
    env = gym.make("CartPole-v1")

    net = common.LinearNetwork(env.observation_space.shape[0], 128, env.action_space.n).to(DEVICE)
    opt = optim.Adam(net.parameters(), lr=1e-2)

    agent = common.PolicyAgent(net, env, DEVICE)
    collector = common.DataCollector(agent, env)
    selector = common.EpisodeSelector(collector, N_CANDIDATES, PERCENTILE)
    iterator = common.DataIterator(selector)

    for X_batch, y_batch, reward in iterator.get_samples():
        X_batch_v, y_batch_t = torch.FloatTensor(X_batch).to(DEVICE), torch.LongTensor(y_batch).to(DEVICE)

        opt.zero_grad()
        y_pred_v = net(X_batch_v)
        loss_v = F.cross_entropy(y_pred_v, y_batch_t)
        loss_v.backward()
        opt.step()
        print(y_batch_t, reward)

 

1 odpowiedź

0 głosów
odpowiedź 21 lipca 2021 przez Fut_techs Początkujący (420 p.)
Hej, wyślij co pokazuje terminal podczas errora by nieco bardziej nakreślić sytuację. Pozdrawiam.
komentarz 21 lipca 2021 przez Aqua 4 Gaduła (3,220 p.)
Chodzi o to, że wszystko wydaje się działać, ale nie ma żadnego postępu w trenowaniu.

Podobne pytania

–1 głos
0 odpowiedzi 192 wizyt
0 głosów
0 odpowiedzi 1,153 wizyt

92,576 zapytań

141,426 odpowiedzi

319,652 komentarzy

61,961 pasjonatów

Motyw:

Akcja Pajacyk

Pajacyk od wielu lat dożywia dzieci. Pomóż klikając w zielony brzuszek na stronie. Dziękujemy! ♡

Oto polecana książka warta uwagi.
Pełną listę książek znajdziesz tutaj.

Akademia Sekuraka

Kolejna edycja największej imprezy hakerskiej w Polsce, czyli Mega Sekurak Hacking Party odbędzie się już 20 maja 2024r. Z tej okazji mamy dla Was kod: pasjamshp - jeżeli wpiszecie go w koszyku, to wówczas otrzymacie 40% zniżki na bilet w wersji standard!

Więcej informacji na temat imprezy znajdziecie tutaj. Dziękujemy ekipie Sekuraka za taką fajną zniżkę dla wszystkich Pasjonatów!

Akademia Sekuraka

Niedawno wystartował dodruk tej świetnej, rozchwytywanej książki (około 940 stron). Mamy dla Was kod: pasja (wpiszcie go w koszyku), dzięki któremu otrzymujemy 10% zniżki - dziękujemy zaprzyjaźnionej ekipie Sekuraka za taki bonus dla Pasjonatów! Książka to pierwszy tom z serii o ITsec, który łagodnie wprowadzi w świat bezpieczeństwa IT każdą osobę - warto, polecamy!

...