Commit cf22f292 authored by Avisek Naug 🎨
Browse files

fix how agent weights are saved

parent 4e615160
......@@ -40,7 +40,7 @@ from numpy.random import randn
from rl.agents import DDPGAgent
from keras.models import load_model
from agent import get_agent, train_agent
from agent import get_agent, train_agent, test_agent
from HVAC_Environment import Env
from modelretrain import retrain, createmodel
from ah_api import weeklyrelearndata
......@@ -326,6 +326,7 @@ def learn_control(source: str, save_to: str, relearn_window: int, duration: int,
# get the new traindata'RELEARN: Fetching relearning data.')
relearndf = weeklyrelearndata(solardatapath=source, weeks=relearn_window)
......@@ -351,9 +352,31 @@ def learn_control(source: str, save_to: str, relearn_window: int, duration: int,
# Do not reinitialize the agent, instead use a COPY of the existing agent file
re_agent.load_weights(save_to)'RELEARN: Training new control agent.')
train_agent(agent=re_agent, env=re_env, steps=duration, dest=save_to)
_ , re_agent = train_agent(agent=re_agent, env=re_env, steps=duration, dest=save_to)
# testing and comparing
re_agent.save_weights('latestagent_weights.h5f', overwrite=True)
perfmetric_best = test_agent(re_agent, re_env, weights=save_to, actions=[])
perfmetric_latest = test_agent(re_agent, re_env, weights='latestagent_weights.h5f', actions=[])
episode_rwd_best = sum(perfmetric_best.metrics[0]['reward'])
episode_rwd_latest = sum(perfmetric_latest.metrics[0]['reward'])
if episode_rwd_latest >= episode_rwd_best:'RELEARN: Best control policy weights changed.')
re_agent.save_weights(save_to, overwrite=True)
with open('.lastreward', 'w') as f:
else:'RELEARN: Best control policy weights remain unchanged.')
with open('.lastreward', 'w') as f:
f.write(str(episode_rwd_latest))'RELEARN: Control policy adapted.')
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment