Hello! I have been trying to design a CNN for image analysis. The CNN is training on simulated images of size 132 x 132 x 6 (spatial, spatial, channel). The simulated images are computed using a bi-exponential equation of the form . In the CNN, the input images are forward passed through the network to generate four feature maps (, , , and ) which then are scaled and used to calculated the predicted image signals, . The predicted image signals are then compared to the input image signals S using the mean squared error loss function and the gradients are updated. The problem is the network is not learning. After some inspection I noticed that the gradients are all going to zero, however I'm not sure how to fix this problem. I have tried changing the learning rate, adam v. sgdm optimizers, and the mini-batch size, however I encounter the same problem. Any advice/feedback is greatly appreciated!
Also, I have removed parts of the code to make it as simple as possible for the time being, but will add in validation and testing loops.
% Image Parameters
rng(1);
imageSize = [132, 132];
bValue = [50 100 150 250 500 800]; % non-zero diffusion weightings
numbVal = length(bValue);
minDf = 0.0017;
maxDf = 0.107;
minf = 0.1;
maxf = 0.5;
minDs = 0.0003;
maxDs = 0.0017;
DfSim = minDf + (maxDf-minDf).*rand(10,1);
fSim = minf + (maxf-minf).*rand(10,1);
DsSim = minDs + (maxDs-minDs).*rand(10,1);
numIm = length(DfSim) * length(fSim) * length(DsSim); % number of 132 x 132 x 6 images
tissue = ones(imageSize);
bValue = reshape(bValue, [1,1,numbVal]); % Reshape bValue for matrix operation
% Prepare a directory to store the simulated images
outputDir = fullfile(tempdir, 'SimulatedDW-MRI');
if ~exist(outputDir, 'dir')
mkdir(outputDir);
end
% Initialize a table to store the image file paths and parameters
fprintf('Total simulated images: %d\n', numIm);
imageData = table('Size', [0 4],...
'VariableTypes', {'cell', 'double', 'double', 'double'},...
'VariableNames', {'imageFilePath', 'DfSim', 'fSim', 'DsSim'});
% Start the timer
tic;
% Loop through each combination of DfSim, fSim, and DsSim
imageIdx = 0;
S = zeros([imageSize length(bValue) numIm]);
for DfIdx = 1:length(DfSim)
for fIdx = 1:length(fSim)
for DsIdx = 1:length(DsSim)
imageIdx = imageIdx + 1;
% Calculate the diffusion signal for each b value for each channel
S(:,:,:,imageIdx) = tissue .* ((fSim(fIdx) .* exp(-bValue .* DfSim(DfIdx))) + ((1-fSim(fIdx)) .* exp(-bValue .* DsSim(DsIdx))));
% Track progress
fprintf('Processing image %d out of %d\n', imageIdx, numIm);
end
end
end
for imageIdx = 1:numIm
fileName = sprintf('%s/image%d.mat', outputDir, imageIdx); % Write the image to a .mat file
S_single = S(:,:,:,imageIdx);
save(fileName, 'S_single');
DfIdx = ceil(imageIdx / (length(fSim)*length(DsSim)));
fIdx = ceil((imageIdx - (DfIdx-1)*length(fSim)*length(DsSim)) / length(DsSim));
DsIdx = imageIdx - (DfIdx-1)*length(fSim)*length(DsSim) - (fIdx-1)*length(DsSim);
imageData(imageIdx, :) = {fileName, DfSim(DfIdx), fSim(fIdx), DsSim(DsIdx)};
fprintf('Saving image %d out of %d\n', imageIdx, numIm);
end
elapsedTime = toc;
fprintf('Computation time: %.2f seconds\n', elapsedTime);
%% Split data in training, validation, and testing sets
trainSplit = 0.8;
valSplit = 0.1;
testSplit = 0.1;
n = height(imageData);
idx = randperm(n);
trainIdx = idx(1:round(trainSplit*n));
valIdx = idx(round(trainSplit*n)+1:round((trainSplit+valSplit)*n));
testIdx = idx(round((trainSplit+valSplit)*n)+1:end);
imageDataTrain = imageData(trainIdx, :);
imageDataVal = imageData(valIdx, :);
imageDataTest = imageData(testIdx, :);
trainImds = fileDatastore(imageDataTrain.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
trainLabelsDatastore = arrayDatastore(imageDataTrain{:, {'DfSim', 'fSim', 'DsSim'}});
trainCombinedDatastore = combine(trainImds, trainLabelsDatastore);
valImds = fileDatastore(imageDataVal.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
valLabelsDatastore = arrayDatastore(imageDataVal{:, {'DfSim', 'fSim', 'DsSim'}});
valCombinedDatastore = combine(valImds, valLabelsDatastore);
testImds = fileDatastore(imageDataTest.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
testLabelsDatastore = arrayDatastore(imageDataTest{:, {'DfSim', 'fSim', 'DsSim'}});
testCombinedDatastore = combine(testImds, testLabelsDatastore);
%% Define the network layers
lgraph = layerGraph();
Layers = [
imageInputLayer([132 132 6],"Name","imageinput","Normalization","none")
convolution2dLayer([1 1],32,"Name","conv_1","Padding","same")
batchNormalizationLayer("Name","batchnorm_1")
leakyReluLayer("Name","relu_1")
dropoutLayer(0.02,"Name","dropout_1")
convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
leakyReluLayer("Name","relu_2")
dropoutLayer(0.02,"Name","dropout_2")
convolution2dLayer([1 1],64,"Name","conv_3","Padding","same")
batchNormalizationLayer("Name","batchnorm_2")
leakyReluLayer("Name","relu_3")
dropoutLayer(0.02,"Name","dropout_3")
convolution2dLayer([3 3],64,"Name","conv_4","Padding","same")
leakyReluLayer("Name","relu_4")
dropoutLayer(0.02,"Name","dropout_4")
convolution2dLayer([1 1],128,"Name","conv_5","Padding","same")
batchNormalizationLayer("Name","batchnorm_3")
leakyReluLayer("Name","relu_5")
dropoutLayer(0.02,"Name","dropout_5")
convolution2dLayer([3 3],128,"Name","conv_6","Padding","same")
leakyReluLayer("Name","relu_6")
dropoutLayer(0.02,"Name","dropout_6")
convolution2dLayer([1 1],64,"Name","conv_7","Padding","same")
batchNormalizationLayer("Name","batchnorm_4")
leakyReluLayer("Name","relu_7")
dropoutLayer(0.02,"Name","dropout_7")
convolution2dLayer([3 3],64,"Name","conv_8","Padding","same")
leakyReluLayer("Name","relu_8")
dropoutLayer(0.02,"Name","dropout_8")
convolution2dLayer([1 1],32,"Name","conv_9","Padding","same")
batchNormalizationLayer("Name","batchnorm_5")
leakyReluLayer("Name","relu_9")
dropoutLayer(0.02,"Name","dropout_9")
convolution2dLayer([3 3],32,"Name","conv_10","Padding","same")
leakyReluLayer("Name","relu_10")
dropoutLayer(0.02,"Name","dropout_10")
convolution2dLayer([1 1],4,"Name","conv_11","Padding","same")
sigmoidLayer("Name","sigmoid")];
lgraph = addLayers(lgraph,Layers);
dlnet = dlnetwork(lgraph);
plot(lgraph);
%% Training loop
numEpochs = 200;
miniBatchSize = 10;
initialLearnRate = 0.01;
decay = 0.00001;
gradDecay = 0.9;
sqGradDecay = 0.999;
mbq = minibatchqueue(trainCombinedDatastore,...
'MiniBatchSize', miniBatchSize,...
'MiniBatchFormat', {'SSCB', 'CB'}, ...
'OutputAsDlarray', [1, 1],...
'OutputEnvironment', 'auto');
averageGrad = [];
averageSqGrad = [];
numObservationsTrain = imageIdx;
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
plots = 'training-progress';
if strcmp(plots, 'training-progress')
figure
lineLossTrain = animatedline;
xlabel("Total Iterations")
ylabel("Loss")
end
epoch = 0;
iteration = 0;
start = tic;
% Loop over epochs.
while epoch < numEpochs
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq)
iteration = iteration + 1;
% Read mini-batch of data.
[dlX, dlT] = next(mbq);
[loss, gradients, state] = dlfeval(@modelLoss,dlnet,dlX);
dlnet.State = state;
% Determine learning rate for time-based decay learning rate schedule.
learnRate = initialLearnRate/(1 + decay*iteration);
% Update network parameters
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients,averageGrad,averageSqGrad,...
iteration, learnRate, gradDecay, sqGradDecay);
% Extract weights of first convolution layer
conv1Weights = dlnet.Layers(2).Weights;
% Print or save the weights
disp('Weights of conv_1 layer:');
disp(conv1Weights);
if strcmp(plots, 'training-progress')
D = duration(0,0,toc(start),'Format','hh:mm:ss');
addpoints(lineLossTrain, iteration, double(gather(extractdata(loss))));
title("Epoch: " + epoch + " , Elapsed: " + string(D));
drawnow
end
end
end
%% Custom loss function
function [loss, gradients,state] = modelLoss(dlnet, dlX)
% Forward data through network.
[dlY, state] = forward(dlnet, dlX);
% Calculate parameter maps
fMap = dlY(:,:,1,:).*0.5;
DfMap = dlY(:,:,2,:).*0.107;
S0Map = (dlY(:,:,3,:).*0.6) + 0.7;
DsMap = dlY(:,:,4,:).*0.0017;
% diffusion weightings
dlB = [50 100 150 250 500 800];
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
% Calculate the mse loss
loss = mse(Spred, dlX);
% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss, dlnet.Learnables);
end