% Training an LSTM network to predict the derivative of a sequence.
%
% The derivative is the causal (backward) first difference
% y(t) = x(t) - x(t-1), with x(0) taken to be 0. For example:
%
% Input:  1,2,5,3
% Output: 1,1,3,-2
%
% Each target depends only on the current and previous inputs, so a
% unidirectional LSTM (which never sees future time steps) can learn it.
N = 1000;        % number of training samples
NTest = N;       % number of test samples
seqLen = N;      % length of each sequence; must divide both N and NTest
maxValue = 100;  % inputs are integers drawn uniformly from [0, maxValue]

assert(mod(N, seqLen) == 0 && mod(NTest, seqLen) == 0, ...
    'seqLen must evenly divide both N and NTest.');

rng(0);  % reproducible inputs
numbers = randi([0, maxValue], 1, N+NTest);

% One sequence per row: reshape fills columns of length seqLen, and the
% transpose turns each column (a contiguous run of samples) into a row.
seq = reshape(numbers(1:N), seqLen, [])';
testSeq = reshape(numbers(N+1:end), seqLen, [])';

% Targets are computed per sequence, each starting from x(0) = 0, so no
% target depends on samples outside its own sequence (the network starts
% every sequence with a fresh state and cannot see them).
seqY = diff([zeros(size(seq, 1), 1), seq], 1, 2);
testSeqY = diff([zeros(size(testSeq, 1), 1), testSeq], 1, 2);

% Flattened targets in the same sample order as `numbers`;
% gradients(N+1:end) holds the test targets.
gradients = [reshape(seqY', 1, []), reshape(testSeqY', 1, [])];
% Standardize the inputs to zero mean and unit variance. The statistics
% come from the training inputs only and are applied unchanged to the
% test inputs. The targets are left unscaled.
mu = mean(seq(:));
sig = std(seq(:));
if sig == 0
    % Constant training inputs: avoid dividing by zero (only centre them).
    sig = 1;
end
stdSeq = (seq - mu) / sig;
stdTest = (testSeq - mu) / sig;

% Prepare inputs and responses as column cell arrays holding one
% 1-by-seqLen row vector (one sequence) per cell. For the expected format
% of sequence predictors and responses, see
% https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
XTrain = num2cell(stdSeq, 2);
XTest = num2cell(stdTest, 2);

YTrain = num2cell(seqY, 2);
YTest = num2cell(testSeqY, 2);

% Define the network: a sequence input with 1 feature per time step, two
% stacked LSTM layers that emit an output at every time step
% ('OutputMode','sequence'), and a fully connected + regression head that
% maps each time step to a single predicted value.
%
% The first argument of lstmLayer (numHiddenUnits) is the size of the
% hidden state, i.e. how much information the layer carries from one time
% step to the next. It is NOT a window length: the hidden state can carry
% information from all previous time steps of the sequence.
layers = [...
    sequenceInputLayer(1)
    lstmLayer(1,'OutputMode','sequence')
    lstmLayer(1,'OutputMode','sequence')
    fullyConnectedLayer(1)
    regressionLayer];

% Training options: Adam with gradient clipping (threshold 1) and a
% piecewise schedule that multiplies the learning rate by 0.75 every
% 50 epochs.
% NOTE(review): an initial learn rate of 1 is far above the usual Adam
% range (roughly 1e-3 to 1e-2); confirm this is intended.
options = trainingOptions('adam', ...
    'MaxEpochs',100, ...
    'MiniBatchSize', 5, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',1, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',50, ...
    'LearnRateDropFactor',0.75, ...
    'Verbose',0, ...
    'Plots','training-progress');

net = trainNetwork(XTrain, YTrain, layers, options);

% Predict on the test sequences. predict returns one 1-by-seqLen row per
% test sequence; cell2mat stacks them into a matrix with one sequence per
% row, and transposing before flattening restores the original sample
% order. YPred therefore lines up with gradients(N+1:end) even when there
% is more than one test sequence. Predictions are kept as continuous
% values (no integer rounding) so the plot shows the actual model error.
YPred = cell2mat(predict(net, XTest));
YPred = reshape(YPred', 1, []);

plot(1:NTest, YPred, 1:NTest, gradients(N+1:end));
legend('Predicted', 'Actual');