• Najnowsze pytania
  • Bez odpowiedzi
  • Zadaj pytanie
  • Kategorie
  • Tagi
  • Zdobyte punkty
  • Ekipa ninja
  • IRC
  • FAQ
  • Regulamin
  • Książki warte uwagi

Gdzie jest błąd? - Program do uczenia maszynowego

VPS Starter Arubacloud
0 głosów
306 wizyt
pytanie zadane 3 lipca 2021 w Python przez Aqua 4 Gaduła (3,220 p.)

Od pewnego czasu próbuję zaimplementować algorytm uczenie przez wzmacnianie "CrossEntropy" (wyjaśniony tutaj). Problem polega na tym, że agent nie uczy się niczego, a po kilku iteracjach sieć neuronowa zaczyna zwracać zawsze tą samą wartość. Debuguję ten program od kilku godzin, proszę pomóżcie.

Kod:

commom.py:

# common.py

from collections import namedtuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


EpisodeStep = namedtuple("EpisodeStep", field_names=("state", "action", "reward", "next_state", "done", "info"))
Episode = namedtuple("Episode", field_names=("total_reward", "history"))


class LinearNetwork(nn.Module):
    def __init__(self, input_shape, hidden_size, output_shape):
        super(LinearNetwork, self).__init__()
        self.nn = nn.Sequential(
            nn.Linear(in_features=input_shape, out_features=hidden_size), nn.ELU(),
            nn.Linear(in_features=hidden_size, out_features=output_shape))

    def forward(self, x):
        return self.nn(x)


class PolicyAgent(object):
    def __init__(self, net, env, device):
        super(PolicyAgent, self).__init__()
        self.net = net
        self.env = env
        self.device = device

    def __call__(self, observation):
        probs = F.softmax(self.net(
            torch.FloatTensor(observation[None, ...]).to(self.device))[0], dim=0).data.cpu().numpy()
        action = np.random.choice(probs.shape[0], p=probs)
        next_state, rewards, done, info = self.env.step(action)
        return EpisodeStep(state=observation, action=action, reward=rewards,
                           next_state=next_state, done=done, info=info)


class DataCollector(object):
    def __init__(self, agent, env):
        self.agent = agent
        self.env = env

    def iterate_steps(self):
        while True:
            obs = self.env.reset()
            done = False
            while not done:
                step = self.agent(obs)
                obs, done = step.next_state, step.done
                yield step

    def iterate_episodes(self):
        history, total_reward = [], 0.
        for step in self.iterate_steps():
            history.append(step)
            total_reward += step.reward
            if step.done:
                yield Episode(total_reward=total_reward, history=history)
                history.clear()
                total_reward = 0.


class EpisodeSelector(object):
    def __init__(self, data_collector, filter_candidates, percentile):
        super(EpisodeSelector, self).__init__()
        self.data_collector = data_collector
        self.filter_candidates = filter_candidates
        self.percentile = percentile

    def get_selected(self):
        generator = self.data_collector.iterate_episodes()
        while True:
            buffer = []
            for _ in range(self.filter_candidates):
                buffer.append(next(generator))
            rewards = np.array(list(map(lambda x: x.total_reward, buffer)))
            borderline = np.percentile(rewards, self.percentile)
            yield [episode for episode, reward in zip(buffer, rewards) if reward >= borderline], np.mean(rewards)


class DataIterator(object):
    def __init__(self, episode_selector):
        super(DataIterator, self).__init__()
        self.episode_selector = episode_selector

    def get_samples(self):
        for elite, reward in self.episode_selector.get_selected():
            states, actions = [], []
            for episode in elite:
                for step in episode.history:
                    states.append(step.state)
                    actions.append(step.action)
            X, y = np.array(states, dtype=np.float32, copy=False), np.array(actions, dtype=np.int32, copy=False)
            yield X, y, reward

main.py:

# main.py

import common
import gym

import torch
import torch.optim as optim
import torch.nn.functional as F


DEVICE = torch.device("cuda")
N_CANDIDATES = 256
PERCENTILE = 95


if __name__ == "__main__":
    env = gym.make("CartPole-v1")

    net = common.LinearNetwork(env.observation_space.shape[0], 128, env.action_space.n).to(DEVICE)
    opt = optim.Adam(net.parameters(), lr=1e-2)

    agent = common.PolicyAgent(net, env, DEVICE)
    collector = common.DataCollector(agent, env)
    selector = common.EpisodeSelector(collector, N_CANDIDATES, PERCENTILE)
    iterator = common.DataIterator(selector)

    for X_batch, y_batch, reward in iterator.get_samples():
        X_batch_v, y_batch_t = torch.FloatTensor(X_batch).to(DEVICE), torch.LongTensor(y_batch).to(DEVICE)

        opt.zero_grad()
        y_pred_v = net(X_batch_v)
        loss_v = F.cross_entropy(y_pred_v, y_batch_t)
        loss_v.backward()
        opt.step()
        print(y_batch_t, reward)

 

1 odpowiedź

0 głosów
odpowiedź 21 lipca 2021 przez Fut_techs Początkujący (420 p.)
Hej, wyślij co pokazuje terminal podczas errora by nieco bardziej nakreślić sytuację. Pozdrawiam.
komentarz 21 lipca 2021 przez Aqua 4 Gaduła (3,220 p.)
Chodzi o to, że wszystko wydaje się działać, ale nie ma żadnego postępu w trenowaniu.

Podobne pytania

–1 głos
0 odpowiedzi 191 wizyt
0 głosów
0 odpowiedzi 1,151 wizyt

92,451 zapytań

141,261 odpowiedzi

319,073 komentarzy

61,853 pasjonatów

Motyw:

Akcja Pajacyk

Pajacyk od wielu lat dożywia dzieci. Pomóż klikając w zielony brzuszek na stronie. Dziękujemy! ♡

Oto polecana książka warta uwagi.
Pełną listę książek znajdziesz tutaj.

Akademia Sekuraka

Akademia Sekuraka 2024 zapewnia dostęp do minimum 15 szkoleń online z bezpieczeństwa IT oraz dostęp także do materiałów z edycji Sekurak Academy z roku 2023!

Przy zakupie możecie skorzystać z kodu: pasja-akademia - użyjcie go w koszyku, a uzyskacie rabat -30% na bilety w wersji "Standard"! Więcej informacji na temat akademii 2024 znajdziecie tutaj. Dziękujemy ekipie Sekuraka za taką fajną zniżkę dla wszystkich Pasjonatów!

Akademia Sekuraka

Niedawno wystartował dodruk tej świetnej, rozchwytywanej książki (około 940 stron). Mamy dla Was kod: pasja (wpiszcie go w koszyku), dzięki któremu otrzymujemy 10% zniżki - dziękujemy zaprzyjaźnionej ekipie Sekuraka za taki bonus dla Pasjonatów! Książka to pierwszy tom z serii o ITsec, który łagodnie wprowadzi w świat bezpieczeństwa IT każdą osobę - warto, polecamy!

...