function [gain,shGain,er,sh,w,predictions] = GLMgain(source,target,varargin)

%GLMgain - Cross-validated gain of a GLM predicting 'target' from 'source'.
%
% Computes the gain of a GLM predicting 'target' from the matrix 'source'
% (each row is an instance, each column is a different variable). Each
% column of 'target' is treated as a separate unit to predict.
%
% For every unit, the data are split into nSets cross-validation sets
% (zero and non-zero target values are distributed separately so that the
% sets are balanced). The GLM is fitted on all sets but one and evaluated
% on the left-out set, using the mean absolute error. The same predictions
% are then shuffled relative to the observations nIterations times to
% obtain a distribution of shuffled errors. The gain is the median (across
% shuffles) of the shuffled error divided by the true error. The median is
% used rather than the mean because, when the errors are not normally
% distributed, the mean can lead to gains that are systematically higher
% than 1; taking the median avoids this.
%
%  USAGE
%
%    [gain,shGain,er,sh,w,predictions] = GLMgain(source,target,<options>)
%
%    source         predictors (one row per instance, one column per variable)
%    target         values to predict (one row per instance, one column per unit)
%    <options>      optional list of property-value pairs (see table below)
%
% =========================================================================
%  Properties    Values
% -------------------------------------------------------------------------
%  'link'        link function to use in GLM (default = 'log')
%  'dist'        distribution to specify in GLM (default = 'poisson';
%                e.g. 'normal' with an 'identity' link minimizes the
%                squared error (y - mu).^2)
%  'nSets'       number of sets the data will be divided into for
%                the cross-validation step (default = 5)
%  'nIterations' number of times the algorithm will shuffle the
%                predictions to estimate the shuffled error
%                (default = 1000)
%  'minEvents'   units are only predicted if they have more than this
%                number of non-zero events (default = 5)
% =========================================================================
%
%  OUTPUT
%
%    gain           median shuffled error divided by the true error
%                   (one row per unit; NaN for units that were not fitted)
%    shGain         median shuffled error divided by the error of each
%                   individual shuffle (one row per unit, one column per
%                   iteration)
%    er             mean absolute prediction error, averaged across sets
%    sh             median (across shuffles) of the shuffled error
%    w              GLM weights (unit x [intercept, predictors] x set)
%    predictions    cross-validated predictions, one column per unit that
%                   passed the 'minEvents' criterion (units that did not
%                   pass are NOT included as columns)
%
% Copyright (C) 2017 by Ralitsa Todorova
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 3 of the License, or
% (at your option) any later version.

% Default values:
nSets = 5;
nIterations = 1000;
minEvents = 5;
link = 'log';
dist = 'poisson';

% Properties must come in name-value pairs
if mod(length(varargin),2) ~= 0
    error('Incorrect number of parameters (type ''help GLMgain'' for details).');
end

for i = 1:2:length(varargin)
    if ~ischar(varargin{i})
        error(['Parameter ' num2str(i+2) ' is not a property (type ''help GLMgain'' for details).']);
    end
    switch(lower(varargin{i}))
        case 'link'
            link = varargin{i+1};
        case 'dist'
            dist = varargin{i+1};
        case 'nsets'
            nSets = varargin{i+1};
            if ~isscalar(nSets) || mod(nSets,1)>0
                error('Incorrect value for property ''nSets'' (type ''help GLMgain'' for details).');
            end
        case 'niterations'
            nIterations = varargin{i+1};
            if ~isscalar(nIterations) || mod(nIterations,1)>0
                error('Incorrect value for property ''nIterations'' (type ''help GLMgain'' for details).');
            end
        case 'minevents'
            minEvents = varargin{i+1};
            if ~isscalar(minEvents) || mod(minEvents,1)>0
                error('Incorrect value for property ''minEvents'' (type ''help GLMgain'' for details).');
            end
        otherwise
            error(['Unknown property ''' num2str(varargin{i}) ''' (type ''help GLMgain'' for details).']);
    end
end

% Only keep units with more than minEvents non-zero events
notempty = sum(target>0,1)>minEvents;
target(:,~notempty) = [];
[nEvents,nUnits] = size(target);

% Silence warnings while fitting. The complete warning state is saved and
% restored when the function exits (including on error), so that the
% caller's per-identifier warning settings are preserved.
warningState = warning;
restoreWarnings = onCleanup(@() warning(warningState)); %#ok<NASGU>
warning('off','all');

gain = nan(nUnits,1);
shGain = nan(nUnits,nIterations);
er = nan(nUnits,1);
sh = nan(nUnits,1);
w = nan(nUnits,size(source,2)+1,nSets);
predictions = nan(size(target));

for unit = 1:nUnits
    % Split data into nSets balanced sets
    setID = nan(nEvents,1);
    zero = target(:,unit)==0;
    % initialise (number sets sequentially); zero and non-zero events are
    % handled separately so that each set gets its share of both
    setID(zero) = ceil(linspace(1/sum(zero),1,sum(zero))*nSets);
    setID(~zero) = ceil(linspace(1/sum(~zero),1,sum(~zero))*nSets);
    % scramble the set labels within the zero and non-zero groups
    setID(zero) = Scramble(setID(zero));
    setID(~zero) = Scramble(setID(~zero));

    shuffled = nan(nSets,nIterations);
    errors = nan(nSets,1);
    for set = 1:nSets
        test = setID==set;
        observed = target(test,unit);

        % lastwarn is updated even when warnings are not displayed, so it
        % can be used to detect a fit that did not converge
        lastwarn('');
        weights = glmfit(source(~test,:),target(~test,unit),dist,'link',link);
        if strcmp(lastwarn,'Iteration limit reached.') % the algorithm failed to converge
            continue % leave NaNs for this set
        end
        w(unit,:,set) = weights;

        % Predict the left-out set by applying the inverse link to the
        % linear predictor
        eta = source(test,:)*weights(2:end) + weights(1);
        if ischar(link) && strcmp(link,'log')
            prediction = exp(eta);
        elseif ischar(link) && strcmp(link,'identity')
            prediction = eta;
        else
            % any other link (e.g. 'logit', 'probit'): let glmval apply the
            % matching inverse link rather than returning the raw linear
            % predictor
            prediction = glmval(weights,source(test,:),link);
        end
        errors(set,1) = nanmean(abs(observed - prediction));

        % Shuffled control: compare the observations to randomly
        % reordered predictions
        nTest = sum(test);
        for iteration = 1:nIterations
            % get randomly ordered indices
            [~,indices] = sort(rand(nTest,1));
            shuffled(set,iteration) = nanmean(abs(observed - prediction(indices)));
        end

        % Store the predictions of this left-out set. This has to happen
        % inside the loop over sets, otherwise only the last set is kept.
        predictions(test,unit) = prediction;
    end

    meanShuffled = nanmean(shuffled,1); % shuffled error averaged across sets (1 x nIterations)
    medianShuffled = nanmedian(meanShuffled);
    er(unit,1) = nanmean(errors);
    sh(unit,1) = medianShuffled;
    gain(unit,1) = medianShuffled./nanmean(errors(:));
    shGain(unit,:) = medianShuffled./meanShuffled;
end

% Put the results back in the original unit order, with NaNs for the units
% that were not fitted.
% NOTE(review): 'predictions' is deliberately left with one column per
% fitted unit only, to preserve the shape returned to existing callers.
if sum(~notempty)>0
    gain0 = gain; shGain0 = shGain; er0 = er; sh0 = sh; w0 = w;
    nAll = numel(notempty);
    gain = nan(nAll,1);
    er = nan(nAll,1);
    sh = nan(nAll,1);
    shGain = nan(nAll,nIterations);
    w = nan(nAll,size(source,2)+1,nSets);
    gain(notempty) = gain0;
    er(notempty) = er0;
    sh(notempty) = sh0;
    shGain(notempty,:) = shGain0;
    w(notempty,:,:) = w0;
end