Commit af9e4ed4 authored by hazrmard

state as of AAAI2020 submission

parent 069f0237
@@ -88,11 +88,11 @@
```
%% Cell type:code id: tags:
``` python
plot_tanks(env_)
plot_tanks(env_, columns=n_tanks, single_size=(3,3))
```
%% Cell type:markdown id: tags:
### System model
@@ -712,10 +712,24 @@
plt.tight_layout()
fname = pth + env_name + '_' + 'divergence_comparison'
plt.savefig(fname + '.png')
```
%% Cell type:code id: tags:
``` python
# plotting env actions
i=0
for e, p in tqdm(zip(comp_envs, comp), total=len(comp), leave=False):
e.seed(SEED)
a = PPO(e, **ppo_params)
a.policy.load_state_dict(p)
plot_tanks(e, a, columns=n_tanks, single_size=(3,1.2), legend=(i == len(comp) - 1))
plt.show()
i += 1
```
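The cell above assumes `comp` already holds trained policy weights, one `state_dict` per environment in `comp_envs`. A hypothetical sketch of how such a list could be produced and persisted; `PPO` is taken to be the repository's own class exposing a torch `policy` module, and the file name is illustrative, not the notebook's actual code:
```python
# Hypothetical sketch: building `comp`, a list of policy state_dicts,
# one per comparison environment. The training step and file name are
# placeholders, not repository code.
import torch

comp = []
for e in comp_envs:
    e.seed(SEED)
    agent = PPO(e, **ppo_params)
    # ... train `agent` on its (possibly degraded) environment here ...
    comp.append(agent.policy.state_dict())
torch.save(comp, 'comparison_policies.pt')  # reload later with torch.load
```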
%% Cell type:markdown id: tags:
## $\Delta \theta$ Approximations
%% Cell type:code id: tags:
# Fault Tolerant Control Using Reinforcement Learning
Using reinforcement learning to control hybrid systems under degradation. This repository contains code for the following papers:
1. Ahmed, I., Quiñones-Grueiro, M. and Biswas, G. (2020). Fault-Tolerant Control of Degrading Systems with On-Policy Reinforcement Learning. IFAC-PapersOnLine (under review).
Using reinforcement learning to control hybrid systems under degradation.
## Project structure
* `tanks.py`: Definitions of the fuel-tank model and OpenAI `gym` environment classes for use in reinforcement learning (see the usage sketch after this list).
* `utils.py`: Some functions used in the notebooks for data transformation; not relevant to the theory.
* `plotting.py`: Functions for plotting graphs.
* `envs/`: Directory containing environment files for development/deployment with/without GPU packages.
* `systems/`: Various reinforcement learning environments for testing.
* `python-envs/`: Directory containing environment files for development/deployment with/without GPU packages.
* Notebooks:
* E-MAML-*: Code pertaining to the AAAI 2020 submission on enhanced meta-learning.
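A minimal sketch of how the `gym` environments defined in `tanks.py` are typically driven; only the class name `TanksPhysicalEnv` is visible in this diff, so the constructor arguments here are assumptions:
```python
# Minimal usage sketch, assuming tanks.py exposes TanksPhysicalEnv as a
# standard gym.Env; constructor arguments may differ in practice.
from tanks import TanksPhysicalEnv

env = TanksPhysicalEnv()          # may require tank-configuration arguments
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()            # random valve settings
    state, reward, done, info = env.step(action)  # classic gym 4-tuple
```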
## Usage
@@ -29,15 +27,13 @@ This repository depends on [Anaconda](https://docs.conda.io/en/latest/miniconda.
2. Install python dependencies
Anaconda environment files are located in the `envs/` directory. Files with the suffix `_cpu` install libraries without GPU acceleration. Files with the prefix `dev` do not install a couple of packages that I authored; instead, those packages should be placed in the same directory as this repository.
Anaconda environment files are located in the `python-envs/` directory. Files with the suffix `_cpu` install libraries without GPU acceleration. Files with the prefix `dev` do not install a couple of packages that I authored; instead, those packages should be placed in the same directory as this repository.
```
conda env create -f environment.yml # GPU support for pyTorch/tensorflow
conda env create -f environment_cpu.yml
# The dev_*.yml files assume other packages written by the author are already
# The dev_*.yml environment files assume other packages written by the author are already
# in PYTHONPATH. In the notebooks their paths are manually added.
cd python-envs
conda env create -f dev.yml # GPU support for pyTorch/tensorflow
conda env create -f dev_cpu.yml
```
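Each `.yml` file defines the created environment's name in its `name:` field; activate whichever variant was installed, for example:
```
conda env list             # lists the name defined in the .yml file
conda activate <env-name>  # substitute the listed name
```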
@@ -109,7 +109,7 @@ class CartPoleDataEnv(CartPoleEnv):
def plot_cartpole(env, agent=None, state0=None):
def plot_cartpole(env, agent=None, state0=None, maxlen=500, legend=True):
if agent=='left':
actor = lambda s: 0
elif agent=='right':
@@ -128,18 +128,25 @@ def plot_cartpole(env, agent=None, state0=None):
done = False
states = [state]
actions = []
while not done:
t = 0
while not done and t < maxlen:
action = actor(state)
actions.append(action)
state, _, done, _ = env.step(action)
states.append(state)
t += 1
states = np.asarray(states)
actions = np.asarray(actions)
x, theta = states[:, 0], states[:, 2]
xline = plt.plot(x, label='X')[0]
thetaline = plt.plot(theta, label='Angle /rad')[0]
plt.legend()
xline = plt.plot(x, 'b-', label='X')[0]
plt.ylabel('X')
plt.ylim(bottom=-env.x_threshold, top=env.x_threshold)
plt.twinx()
thetaline = plt.plot(theta, 'g:', label='Angle /rad')[0]
plt.ylabel('Angle /rad')
plt.ylim(bottom=-env.theta_threshold_radians, top=env.theta_threshold_radians)
# plt.legend()
# pylint: disable=no-member
im = plt.imshow(actions.reshape(1, -1), aspect='auto', alpha=0.3,
extent=(*plt.xlim(), *plt.ylim()), origin='lower',
@@ -147,5 +154,7 @@ def plot_cartpole(env, agent=None, state0=None):
colors = [im.cmap(im.norm(value)) for value in (0, 1)]
patches = [mpatches.Patch(color=colors[0], label="Left", alpha=0.3),
mpatches.Patch(color=colors[1], label="Right", alpha=0.3)]
plt.legend(handles=[xline, thetaline] + patches)
plt.grid(True)
\ No newline at end of file
plt.grid(True)
if legend:
plt.legend(handles=[xline, thetaline] + patches)
return [xline, thetaline] + patches, ('X', 'Angle /rad', 'Left', 'Right')
\ No newline at end of file
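The revised `plot_cartpole` overlays two y-scales on one axes via `plt.twinx()` and then builds a single legend from explicitly collected handles, since each axes would otherwise only know about its own lines. A self-contained sketch of that pattern with synthetic data, not the repository's cartpole rollouts:
```python
# Self-contained sketch of the twin-axis + shared-legend pattern above,
# using synthetic data rather than the repository's code.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

t = np.linspace(0, 10, 200)
xline = plt.plot(np.sin(t), 'b-', label='X')[0]
plt.ylabel('X')
plt.twinx()                                    # second y-axis, shared x-axis
thetaline = plt.plot(0.1 * np.cos(t), 'g:', label='Angle /rad')[0]
plt.ylabel('Angle /rad')
patch = mpatches.Patch(color='gray', alpha=0.3, label='Action')
plt.legend(handles=[xline, thetaline, patch])  # one legend spanning both axes
plt.show()
```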
@@ -276,7 +276,7 @@ class TanksPhysicalEnv(gym.Env):
median_supply = 0.5 * x_next[self.median_idx] * self.odd
left_supply = sum(x_next[self.left_idx] + median_supply)
right_supply = sum(x_next[self.right_idx] + median_supply)
done = (left_demand > left_supply) or (right_demand > right_supply)
done = (left_demand > left_supply) or (right_demand > right_supply) or self.t > self.episode_length
reward = self.reward(self.t, self.x, action, x_next, done)
self.x = x_next
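The added `self.t > self.episode_length` check caps episode length inside `step` itself. For comparison, a sketch of the equivalent cap using `gym`'s stock `TimeLimit` wrapper, which leaves the environment's termination logic untouched; the constructor arguments and step count are assumptions:
```python
# Alternative sketch: capping episodes with gym's TimeLimit wrapper instead
# of a manual step counter; TanksPhysicalEnv arguments and the 200-step
# limit are assumed for illustration.
from gym.wrappers import TimeLimit
from tanks import TanksPhysicalEnv

env = TimeLimit(TanksPhysicalEnv(), max_episode_steps=200)
```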
@@ -460,7 +460,7 @@ class TanksDataRecurrentEnv(TanksDataEnv):
def plot_tanks(env, agent=None, plot='both'):
def plot_tanks(env, agent=None, plot='both', columns=2, single_size=(6,4), legend=True):
n_tanks = len(env.tanks.heights)
if agent is not None:
x, u, done = [], [], False
@@ -493,10 +493,13 @@ def plot_tanks(env, agent=None, plot='both'):
for i in range(len(u_open)):
x_open[i] = env.step(u_open[i])[0]
plt.figure(figsize=(12, 12))
width, height = single_size
rows = n_tanks // columns + (1 if n_tanks % columns else 0)
figsize = (columns * width, rows * height)
plt.figure(figsize=figsize)
patches = None
for i in range(n_tanks):
plt.subplot(n_tanks // 2, 2, i+1)
plt.subplot(rows, columns, i+1)
plt.ylim(0, 1.05 * max(env.tanks.heights))
if plot in ('open', 'both'):
plt.plot(x_open[:, i], '--', label='Open' if i==n_tanks-1 else None)
@@ -512,8 +515,12 @@ def plot_tanks(env, agent=None, plot='both'):
patches = [mpatches.Patch(color=colors[0], label="Closed", alpha=0.3),
mpatches.Patch(color=colors[1], label="Opened", alpha=0.3),]
plt.plot(x[:, i], '-', label='RL' if i==n_tanks-1 else None)
if i !=0 and i % columns !=0:
plt.gca().set_yticklabels([])
plt.ylabel('Tank ' + str(i + 1))
if i >= 4: plt.xlabel('Time /s')
if (i == n_tanks-2) and patches is not None: plt.legend(handles=patches)
if i==n_tanks-1: plt.legend()
plt.grid(True)
\ No newline at end of file
if i >= columns * (rows-1) and legend: plt.xlabel('Time /s')
if (i == n_tanks-2) and patches is not None and legend: plt.legend(handles=patches)
if i==n_tanks-1 and legend: plt.legend()
plt.grid(True)
# plt.tight_layout()
\ No newline at end of file
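The row count in `plot_tanks` uses floor division plus a remainder check; for reference, this is ordinary ceiling division and could be written either way (illustration only, not repository code):
```python
# The grid-size arithmetic from plot_tanks, shown equal to ceiling division.
import math

n_tanks, columns = 6, 4
rows = n_tanks // columns + (1 if n_tanks % columns else 0)  # as in the diff
assert rows == math.ceil(n_tanks / columns) == 2
```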