Commit cf22f292 authored by Avisek Naug 🎨
Browse files

fix how agent weights are saved

parent 4e615160
......@@ -40,7 +40,7 @@ from numpy.random import randn
from rl.agents import DDPGAgent
from keras.models import load_model
from agent import get_agent, train_agent
from agent import get_agent, train_agent, test_agent
from HVAC_Environment import Env
from modelretrain import retrain, createmodel
from ah_api import weeklyrelearndata
......@@ -326,6 +326,7 @@ def learn_control(source: str, save_to: str, relearn_window: int, duration: int,
try:
# get the new traindata
log.info('RELEARN: Fetching relearning data.')
relearndf = weeklyrelearndata(solardatapath=source, weeks=relearn_window)
relearndf.to_pickle('relearn.pkl')
......@@ -351,9 +352,31 @@ def learn_control(source: str, save_to: str, relearn_window: int, duration: int,
os.remove('relearn.pkl')
# Do not reinitialize the agent, instead use a COPY of the existing agent file
re_agent.load_weights('agent_weights.h5f')
re_agent.load_weights(save_to)
log.info('RELEARN: Training new control agent.')
train_agent(agent=re_agent, env=re_env, steps=duration, dest=save_to)
_ , re_agent = train_agent(agent=re_agent, env=re_env, steps=duration, dest=save_to)
# testing and comparing
re_agent.save_weights('latestagent_weights.h5f', overwrite=True)
re_env.testenv()
perfmetric_best = test_agent(re_agent, re_env, weights=save_to, actions=[])
re_env.testenv()
perfmetric_latest = test_agent(re_agent, re_env, weights='latestagent_weights.h5f', actions=[])
episode_rwd_best = sum(perfmetric_best.metrics[0]['reward'])
episode_rwd_latest = sum(perfmetric_latest.metrics[0]['reward'])
if episode_rwd_latest >= episode_rwd_best:
log.info('RELEARN: Best control policy weights changed.')
re_agent.save_weights(save_to, overwrite=True)
with open('.lastreward', 'w') as f:
f.write(str(episode_rwd_latest))
else:
log.info('RELEARN: Best control policy weights remain unchanged.')
with open('.lastreward', 'w') as f:
f.write(str(episode_rwd_latest))
log.info('RELEARN: Control policy adapted.')
signal_reload.set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment