[RLlib|Custom Policy] Custom Policy Implementation in Reinforcement Learning #49334

Open
XinPoi opened this issue Dec 18, 2024 · 0 comments
Labels
bug (Something that is supposed to be working; but isn't) · rllib (RLlib related issues) · triage (Needs triage, e.g. priority, bug/not-bug, and owning component)

XinPoi commented Dec 18, 2024

What happened + What you expected to happen

I am a beginner in reinforcement learning and am currently trying to implement my own algorithm with a custom RLlib Policy. While doing so, I noticed that the default fields in the sample batch do not cover what my networks need for training, so, following the example in the official documentation, I return additional per-step information from the compute_actions method (via the extra action fetches dict).

However, when I inspect the samples argument passed to the learn_on_batch method, the extra keys are not there. Could anyone help me understand why this is happening?
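
One thing I am unsure about (just my guess, untested): do I also need to declare these extra keys as view requirements so the sample collector keeps them when building the train batch? Roughly something like this in the policy's __init__:

from ray.rllib.policy.view_requirement import ViewRequirement

# Untested guess: tell RLlib to keep my custom columns in the train batch.
self.view_requirements["some_value"] = ViewRequirement()
self.view_requirements["other_value"] = ViewRequirement()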

Versions / Dependencies

ray 2.0.0
python 3.8.19

Reproduction script

import math
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.distributions import Beta, Normal
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

class BetaActor(nn.Module):
    def __init__(self, state_dim, action_dim, net_width):
        super(BetaActor, self).__init__()

        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.alpha_head = nn.Linear(net_width, action_dim)
        self.beta_head = nn.Linear(net_width, action_dim)

    def forward(self, state):
        a = torch.tanh(self.l1(state))
        a = torch.tanh(self.l2(a))

        # a = F.relu(self.l1(state))
        # a = F.relu(self.l2(a))

        alpha = F.softplus(self.alpha_head(a)) + 1.0 
        beta = F.softplus(self.beta_head(a)) + 1.0

        return alpha, beta

    def get_dist(self, state):
        alpha, beta = self.forward(state)
        dist = Beta(alpha, beta)
        return dist

    def deterministic_act(self, state):
        alpha, beta = self.forward(state)
        mode = alpha / (alpha + beta)
        return mode

class Critic(nn.Module):
    def __init__(self, state_dim, net_width):
        super(Critic, self).__init__()

        self.C1 = nn.Linear(state_dim, net_width)
        self.C2 = nn.Linear(net_width, net_width)
        self.C3 = nn.Linear(net_width, 1)

    def forward(self, state):
        v = torch.tanh(self.C1(state))
        v = torch.tanh(self.C2(v))
        v = self.C3(v)
        return v

class PPOContinuous(Policy):
    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)

        self.state_dim = observation_space.shape[0]
        self.action_dim = action_space.shape[0]

        self.net_width = 256
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device("cpu")
        self.entropy_coef = 1e-3
        self.entropy_coef_decay = 0.99
        self.gamma = 0.99
        self.lambd = 0.95
        self.actor_lr = 1e-4
        self.critic_lr = 1e-3
        self.Distribution = ['Beta', 'Gauss'][0]
        self.a_optim_batch_size = 512
        self.c_optim_batch_size = 512
        self.K_epochs = 10
        self.clip_rate = 0.2

        # Build Actor
        if self.Distribution == 'Beta':
            self.actor = BetaActor(self.state_dim, self.action_dim, self.net_width).to(self.device)
        elif self.Distribution == 'Gauss':
            # Note: GaussianActor is not defined in this repro script; only the Beta branch is exercised.
            self.actor = GaussianActor(self.state_dim, self.action_dim, self.net_width).to(self.device)
        else:
            raise ValueError(f"Unknown Distribution: {self.Distribution}")
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)

        # Build Critic
        self.critic = Critic(self.state_dim, self.net_width).to(self.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

    def compute_actions(self, obs, state_batches=None, **kwargs):
        deterministic = False
        with torch.no_grad():
            obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
            if deterministic:
                # Only used when evaluating the policy; makes the performance more stable.
                a = self.actor.deterministic_act(obs)
                return a.cpu().numpy()[0], [], {}
            else:
                # Only used when interacting with the env.
                dist = self.actor.get_dist(obs)
                a = dist.sample()
                a = torch.clamp(a, 0, 1)
                action_logprob = dist.log_prob(a).cpu().numpy()
                # return a.cpu().numpy(), [], {SampleBatch.ACTION_LOGP: action_logprob}

                action_info_batch = {
                    "some_value": ["foo" for _ in obs],
                    "other_value": [12345 for _ in obs],
                }
                return a.cpu().numpy(), [], action_info_batch
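                # The extra keys above ("some_value", "other_value") are what I
                # expect to see again in the SampleBatch passed to learn_on_batch().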

    def learn_on_batch(self, samples: SampleBatch):
        assert "other_value" in samples.keys()
        assert "some_value" in samples.keys()
        self.entropy_coef *= self.entropy_coef_decay

        '''Prepare PyTorch data from Numpy data'''
        s = torch.from_numpy(samples[SampleBatch.CUR_OBS]).to(self.device)
        a = torch.from_numpy(samples[SampleBatch.ACTIONS]).to(self.device)
        r = torch.from_numpy(samples[SampleBatch.REWARDS]).to(self.device)
        s_next = torch.from_numpy(samples[SampleBatch.NEXT_OBS]).to(self.device)
        done = torch.from_numpy(samples[SampleBatch.DONES]).to(self.device)
        logprob_a = torch.from_numpy(samples[SampleBatch.ACTION_LOGP]).to(self.device)

        ''' Use TD+GAE+LongTrajectory to compute Advantage and TD target'''
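        # GAE recursion, computed backwards over the batch:
        #   delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # The critic's TD target is then td_target_t = A_t + V(s_t).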
        adv = []
        gae = 0
        with torch.no_grad():  # adv and v_target have no gradient
            vs = self.critic(s)
            vs_ = self.critic(s_next)
            # Reshape rewards/dones to column vectors so broadcasting matches vs_/vs of shape (N, 1).
            deltas = r.view(-1, 1) + self.gamma * (~done).view(-1, 1) * vs_ - vs
            for delta, d in zip(reversed(deltas.cpu().flatten().numpy()), reversed(done.cpu().flatten().numpy())):
                gae = delta + self.gamma * self.lambd * gae * (1.0 - d)
                adv.insert(0, gae)
            adv = torch.tensor(adv, dtype=torch.float).view(-1, 1).to(self.device)
            # adv normalize
            td_target = adv + vs
            adv = (adv - adv.mean()) / (adv.std() + 1e-4)  # sometimes helps

        """Slice long trajectopy into short trajectory and perform mini-batch PPO update"""
        a_optim_iter_num = int(math.ceil(s.shape[0] / self.a_optim_batch_size))
        c_optim_iter_num = int(math.ceil(s.shape[0] / self.c_optim_batch_size))

        for i in range(self.K_epochs):

            # Shuffle the trajectory; shuffling helps training.
            perm = np.arange(s.shape[0])
            np.random.shuffle(perm)
            perm = torch.LongTensor(perm).to(self.device)
            s, a, td_target, adv, logprob_a = \
                s[perm].clone(), a[perm].clone(), td_target[perm].clone(), adv[perm].clone(), logprob_a[
                    perm].clone()

            '''update the actor'''
            for j in range(a_optim_iter_num):
                index = slice(j * self.a_optim_batch_size, min((j + 1) * self.a_optim_batch_size, s.shape[0]))
                distribution = self.actor.get_dist(s[index])
                dist_entropy = distribution.entropy().sum(1, keepdim=True)
                logprob_a_now = distribution.log_prob(a[index])
                # ratio = pi_new / pi_old == exp(log(pi_new) - log(pi_old))
                ratio = torch.exp(logprob_a_now.sum(1, keepdim=True) - logprob_a[index].sum(1, keepdim=True))

                surr1 = ratio * adv[index]
                surr2 = torch.clamp(ratio, 1 - self.clip_rate, 1 + self.clip_rate) * adv[index]
                a_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy

                self.actor_optimizer.zero_grad()
                a_loss.mean().backward()
                # clip grad
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 40)
                self.actor_optimizer.step()

            '''update the critic'''
            for j in range(c_optim_iter_num):
                index = slice(j * self.c_optim_batch_size, min((j + 1) * self.c_optim_batch_size, s.shape[0]))
                c_loss = (self.critic(s[index]) - td_target[index]).pow(2).mean()

                self.critic_optimizer.zero_grad()
                c_loss.backward()
                self.critic_optimizer.step()

        return {"actor_loss": a_loss.item(), "critic_loss":c_loss.item()}

    def get_weights(self):
        return {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "actor optimizer": self.actor_optimizer.state_dict(),
            "critic optimizer": self.critic_optimizer.state_dict(),
        }

    def set_weights(self, weights):
        self.actor.load_state_dict(weights["actor"])
        self.critic.load_state_dict(weights["critic"])
        self.actor_optimizer.load_state_dict(weights["actor optimizer"])
        self.critic_optimizer.load_state_dict(weights["critic optimizer"])

from ray.rllib.algorithms import Algorithm

class PPOContinuousTrainer(Algorithm):
    def get_default_policy_class(self, config):
        print("policy create")
        return PPOContinuous

config = {
    "env": "Pendulum-v1", 
    "num_workers": 0, 
    "lr": 0.001, 
    "framework": "torch", 
    "simple_optimizer": True,
}

trainer = PPOContinuousTrainer(config=config)

for i in range(100):
    result = trainer.train()
    print(f"Iteration {i}: Loss = {result['info']['learner']['default_policy']['loss']}")

Issue Severity

High: It blocks me from completing my task.

XinPoi added the bug and triage labels on Dec 18, 2024
jcotant1 added the rllib label on Dec 18, 2024