function [gain,shGain,er,sh,w,predictions] = GLMgain(source,target,varargin)
%GLMgain - Cross-validated gain of a GLM predicting 'target' from 'source'.
%
% Fits a GLM predicting each column of 'target' from the matrix 'source'
% (each row is an instance, each column is a different variable), using
% nSets-fold cross-validation. For every test set, the mean absolute
% prediction error is compared to the error obtained after randomly
% permuting the predictions. The gain of a unit is the median (across
% shuffles) of the set-averaged shuffled error, divided by the set-averaged
% true error. The median of the shuffled errors is used rather than the
% mean because, when the errors are not normally distributed, the mean can
% lead to gains that are systematically higher than 1; taking the median
% avoids this.
%
%  USAGE
%
%    [gain,shGain,er,sh,w,predictions] = GLMgain(source,target,<options>)
%
%    source         predictors (one row per instance, one column per variable)
%    target         values to predict (one row per instance, one column per
%                   unit)
%    <options>      optional list of property-value pairs (see table below)
%
%    =========================================================================
%     Properties    Values
%    -------------------------------------------------------------------------
%     'link'        link function passed to glmfit (default = 'log')
%     'dist'        distribution passed to glmfit (default = 'poisson')
%     'nSets'       number of sets the data will be divided into for
%                   the cross-validation step (integer >= 2, default = 5)
%     'nIterations' number of times the predictions are shuffled to
%                   estimate the shuffled error (default = 1000)
%     'minEvents'   units are analysed only if they are non-zero in more
%                   than this number of events (default = 5)
%    =========================================================================
%
%  OUTPUT
%
%    gain           gain of each unit (NaN for units that were not analysed)
%    shGain         for each unit, the median shuffled error divided by the
%                   error of each individual shuffle (units x nIterations)
%    er             true error of each unit, averaged across sets
%    sh             median (across shuffles) of the shuffled error, averaged
%                   across sets
%    w              fitted weights (units x (number of variables + 1) x nSets);
%                   the first weight is the constant term
%    predictions    cross-validated predictions (same size as 'target');
%                   NaN for units that were not analysed and for sets where
%                   the fit failed to converge
%
%  NOTE
%
%    Requires glmfit/glmval (Statistics Toolbox) and the helper 'Scramble'.
%
%  EXAMPLE (from the original author's notes)
%
%    [test(i,2),~,~,~,~,pro(:,i)] = GLMgain(rand(size(s)),double(t>0),'nsets',sum(t>0),'dist','normal','link','identity','minevents',5);
%
% Copyright (C) 2017 by Ralitsa Todorova
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 3 of the License, or
% (at your option) any later version.

% Default values:
nSets = 5;
nIterations = 1000;
minEvents = 5;
link = 'log';
dist = 'poisson';

% Properties must come in name-value pairs
if mod(length(varargin),2) ~= 0,
    error('Incorrect number of parameters (type ''help GLMgain'' for details).');
end

for i = 1:2:length(varargin),
    if ~ischar(varargin{i}),
        error(['Parameter ' num2str(i+2) ' is not a property (type ''help GLMgain'' for details).']);
    end
    switch(lower(varargin{i})),
        case 'link',
            link = varargin{i+1};
        case 'dist',
            dist = varargin{i+1};
        case 'nsets',
            nSets = varargin{i+1};
            % with fewer than 2 sets there would be no data left to train on
            if ~isscalar(nSets) || mod(nSets,1)>0 || nSets<2
                error('Incorrect value for property ''nSets'' (type ''help GLMgain'' for details).');
            end
        case 'niterations',
            nIterations = varargin{i+1};
            if ~isscalar(nIterations) || mod(nIterations,1)>0
                error('Incorrect value for property ''nIterations'' (type ''help GLMgain'' for details).');
            end
        case 'minevents',
            minEvents = varargin{i+1};
            if ~isscalar(minEvents) || mod(minEvents,1)>0
                error('Incorrect value for property ''minEvents'' (type ''help GLMgain'' for details).');
            end
        otherwise,
            error(['Unknown property ''' num2str(varargin{i}) ''' (type ''help GLMgain'' for details).']);
    end
end

% Keep only the units that are non-zero in more than 'minEvents' events
notempty = sum(target>0,1)>minEvents;
target(:,~notempty) = [];
[nEvents,nUnits] = size(target);

% Silence warnings during fitting; the complete warning state is restored
% when this function exits, even if an error is thrown.
warningState = warning;
restoreWarnings = onCleanup(@() warning(warningState)); %#ok<NASGU>
warning('off','all');

gain = nan(nUnits,1);
shGain = nan(nUnits,nIterations);
er = nan(nUnits,1);
sh = nan(nUnits,1);
w = nan(nUnits,size(source,2)+1,nSets);
predictions = nan(size(target));

for unit = 1:nUnits,
    y = target(:,unit);

    %% Split data into nSets balanced sets
    setID = nan(nEvents,1);
    zero = y==0;
    % initialise (number sets sequentially); zero and non-zero events are
    % handled separately so that each set gets a similar share of both
    setID(zero) = ceil(linspace(1/sum(zero),1,sum(zero))*nSets);
    setID(~zero) = ceil(linspace(1/sum(~zero),1,sum(~zero))*nSets);
    % scramble the set labels within each group
    setID(zero) = Scramble(setID(zero));
    setID(~zero) = Scramble(setID(~zero));

    %% Fit on all sets but one, test on the remaining one
    shuffled = nan(nSets,nIterations);
    errors = nan(nSets,1);
    for k = 1:nSets,
        test = setID==k;
        lastwarn('');
        weights = glmfit(source(~test,:),y(~test),dist,'link',link);
        if strcmp(lastwarn,'Iteration limit reached.') % the algorithm failed to converge
            continue % this set keeps NaN weights, errors and predictions
        end
        w(unit,:,k) = weights;

        % Predict the held-out events (weights(1) is the constant term)
        if strcmp(link,'log'),
            prediction = exp(source(test,:)*weights(2:end) + weights(1));
        elseif strcmp(link,'identity'),
            prediction = source(test,:)*weights(2:end) + weights(1);
        else
            % any other link: let glmval apply the matching inverse link
            prediction = glmval(weights,source(test,:),link);
        end
        % store inside the loop so that every set contributes its predictions
        predictions(test,unit) = prediction;

        observed = y(test);
        errors(k,1) = nanmean(abs(observed - prediction));
        nTest = sum(test);
        for iteration = 1:nIterations,
            % get randomly ordered numbers
            [~,indices] = sort(rand(nTest,1));
            shuffled(k,iteration) = nanmean(abs(observed - prediction(indices)));
        end
    end

    %% Summarise across sets
    meanError = nanmean(errors);
    meanShuffled = nanmean(shuffled,1); % one value per shuffle (1 x nIterations)
    medianShuffled = nanmedian(meanShuffled,2);
    er(unit,1) = meanError;
    sh(unit,1) = medianShuffled;
    gain(unit,1) = medianShuffled./meanError;
    shGain(unit,:) = medianShuffled./meanShuffled;
end

% Put the results back in the original unit order, leaving NaNs for the
% units that were excluded for having too few events
if any(~notempty),
    nAllUnits = numel(notempty);
    gain0 = gain;
    shGain0 = shGain;
    er0 = er;
    sh0 = sh;
    w0 = w;
    predictions0 = predictions;
    gain = nan(nAllUnits,1);
    er = nan(nAllUnits,1);
    sh = nan(nAllUnits,1);
    shGain = nan(nAllUnits,nIterations);
    w = nan(nAllUnits,size(source,2)+1,nSets);
    predictions = nan(nEvents,nAllUnits);
    gain(notempty) = gain0;
    er(notempty) = er0;
    sh(notempty) = sh0;
    shGain(notempty,:) = shGain0;
    w(notempty,:,:) = w0;
    predictions(:,notempty) = predictions0;
end