rlx is a Deep RL library written on top of PyTorch, built for educational and research purposes. Most Deep RL libraries/codebases are geared towards reproducing state-of-the-art results on specific tasks (e.g. Atari games), but rlx is NOT. It is meant to be more expressive and modular. Rather than packaging RL algorithms as black boxes, rlx adopts an API that exposes more granular operations to the user, which makes writing new algorithms easier. It is also useful for adding task-specific engineering to a known algorithm.
Concisely, rlx is supposed to be:

- expressive and modular, rather than a collection of black-box algorithm implementations
- granular, exposing low-level RL operations (rollouts, returns, log-probabilities) to the user
- easy to extend, whether with new algorithms or with task-specific engineering on known ones
Below is a basic implementation of PPO (with clipping) in rlx.
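For reference, the training loop performs gradient descent on the negation of PPO's clipped surrogate objective (Schulman et al., 2017), together with a value-function loss and an entropy bonus:

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]$$

where $r_t(\theta) = \pi_\theta(a_t \mid s_t)\,/\,\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)$ is the probability ratio (computed in the code as the exponentiated log-probability difference) and $\hat{A}_t$ is the advantage estimate.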
```python
import torch

# Assumes `agent`, `policy`, `horizon`, `k_epochs` and `clip` are already defined.
base_rollout = agent(policy).episode(horizon)   # sample an episode as a 'Rollout' object
base_rewards, base_logprobs = base_rollout.rewards, base_rollout.logprobs  # 'rewards' and 'logprobs' for all timesteps
base_returns = base_rollout.mc_returns()        # Monte-Carlo estimates of 'returns'

for _ in range(k_epochs):
    rollout = agent(policy).evaluate(base_rollout)  # 'evaluate' an episode against a policy and get a new 'Rollout' object
    logprobs, entropy = rollout.logprobs, rollout.entropy  # 'logprobs' and 'entropy' for all timesteps
    values, = rollout.others                        # .. also 'value' estimates

    ratios = (logprobs - base_logprobs.detach()).exp()
    advantage = base_returns - values

    # Clipped surrogate: take the min over the two surrogate terms
    # (ratio * advantage vs. clipped-ratio * advantage), not over the raw ratios.
    surr = advantage.detach()
    policyloss = - torch.min(ratios * surr, torch.clamp(ratios, 1 - clip, 1 + clip) * surr)
    valueloss = advantage.pow(2)
    loss = policyloss.sum() + 0.5 * valueloss.sum() - 0.01 * entropy.sum()

    agent.zero_grad()
    loss.backward()
    agent.step()
```
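The `mc_returns()` call above provides Monte-Carlo return estimates for the sampled episode. As a point of reference, here is a minimal standalone sketch of that computation; the discount factor `gamma` and the plain-tensor interface are illustrative assumptions, not necessarily rlx's actual internals:

```python
import torch

def mc_returns(rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # NOTE: illustrative sketch, not rlx's implementation.
    # Monte-Carlo return: G_t = r_t + gamma * G_{t+1}, accumulated in a
    # single backward pass over the episode's reward sequence.
    returns, running = [], 0.0
    for r in reversed(rewards.tolist()):
        running = r + gamma * running
        returns.append(running)
    return torch.tensor(list(reversed(returns)))
```

For example, `mc_returns(torch.tensor([1., 0., 1.]))` yields `tensor([1.9801, 0.9900, 1.0000])`.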
Visit the README for further details.