% Estimation.m

% This program estimates factor demand functions using data on factor 
% prices, factor input levels, costs, and output levels at the firm level.
% Factor inputs are estimated using nonparametric convex regression.

%  The program calls the following scripts and functions:
%  load_Christensen.m (loads the data)
%  active.m
%  project.m
%  fle.m,  feq.m, flt.m

% Problem-wide state kept in the global workspace, presumably so that the
% helper functions called below (project, active, fle/feq/flt) can read it.
% NOTE(review): confirm which helpers actually use each variable.
%   N, M     - number of observations and of factor inputs, from size(p)
%   FCOMPTOL - set to 1e-7 below; presumably the tolerance used by the
%              fle/feq/flt comparisons (TODO confirm)
%   Kindex   - [h r] index pairs, one row per concavity constraint
%   m        - cumulative constraint counts: rows m(r)+1:m(r+1) of Kindex
%              belong to observation r
global N M FCOMPTOL Kindex m

% Uncomment to log console output to CG.out (the script ends with 'diary off').
%diary CG.out

% Start-up: announce the run, set the comparison tolerance, and display
% numbers at full precision.
fprintf('STARTING Estimation\n');
FCOMPTOL = 1e-7;
format('long');

% Load the data. This presumably defines p, q and y, which are used below
% without being defined in this file.
fprintf('Loading data...\n');
load_Christensen;   % see load_Christensen for data description

% Initializations
fprintf('initializing...\n');

% N observations by M factor inputs, taken from the factor-price matrix p.
[N,M]  = size(p);
% Worst case is every ordered pair of distinct observations (N*(N-1));
% Kindex is trimmed once the constraint list has been built.
Kindex = zeros(N*(N-1),2);
m      = zeros(N+1,1);
% Sentinels with activ0 > activ1, so that try_active = (activ0<=activ1) is
% false on the first iteration of the main loop.
activ0 = 1e30;
activ1 = 1e29;
% Stack the dependent variable one observation at a time. After the
% transpose, each column of y holds one observation's M values, which
% matches reshape(...,M,N) in the main loop.
% NOTE(review): assumes y is loaded as N-by-M like p; confirm in load_Christensen.
y      = y';
dep    = y(:);
fit1   = dep;
fit0   = zeros(size(fit1));
iter   = 0;
% Accumulated projection shifts: one column per observation, plus column
% N+1 for the monotonicity constraints. Allocate the all-zero sparse matrix
% directly instead of building a dense N*M-by-(N+1) zeros matrix first.
shift  = sparse(N*M,N+1);

% Create index of all possible constraints.
% For each observation r, every other observation h whose row q(h,:) is at
% least q(r,:) in every column contributes the pair [h r] to Kindex.
% m(r+1) stores the running count, so rows m(r)+1:m(r+1) of Kindex are the
% constraints attached to observation r (m(1) stays 0).
k = 0;
for r = 1:N
   for h = 1:N
      if h ~= r && all(q(h,:) >= q(r,:))
         k = k + 1;
         Kindex(k,:) = [h r];
      end
   end
   m(r+1) = k;
end
Kindex = Kindex(1:k,:);

fprintf('ready.\n');

% Summarize the problem: N observations, M inputs, the k concavity
% constraints found above, and one monotonicity restriction per fitted
% value (M*N). All lines are printed with a single concatenated format.
fprintf([ ...
   '==========================================================================\n', ...
   '--------------------------------------------------------------------------\n', ...
   'Problem Description\n', ...
   '\t\t Number of observations:  %g\n', ...
   '\t\t Number of factor inputs:  %g\n', ...
   '\t Cost function restrictions\n', ...
   '\t\t Concavity restrictions:  %g\n', ...
   '\t\t Monotonicity restrictions:  %g\n', ...
   '--------------------------------------------------------------------------\n', ...
   '==========================================================================\n'], ...
   N, M, k, M*N);

% Begin to solve the estimation problem subject to concavity constraints.
% State carried across iterations of the main loop:
%   try_active - whether the simultaneous solve of active constraints is attempted
%   soln       - success flag returned by active()
%   soln_lag   - soln from the previous iteration. It is initialized here
%                because the post-loop check reads it even when the loop
%                body never runs (e.g. dep already compares equal to the
%                all-zero starting fit0).
%   lenf0      - squared length of the last fitted vector, used to check
%                that projections never lengthen the fit
try_active = 0;
soln = 0;
soln_lag = 0;
lenf0  = inf;
% Main iteration. Each pass:
%   1. projects the current fit onto each observation's concavity
%      constraints in turn (via project);
%   2. projects onto the monotonicity constraints (fit >= 0);
%   3. once the active-constraint count stops decreasing, tries to solve
%      all active constraints simultaneously (via active).
% The loop stops when the fit no longer changes (feq(fit1,fit0) everywhere),
% or when the active set is unchanged and the previous simultaneous solve
% succeeded.
while ~all(feq(fit1,fit0))
   iter   = iter+1;
   same_R = 1;            % stays 1 only if no observation's active pattern changes
   fit0   = fit1;
   R      = sparse([]);   % stacked active-constraint rows (built only if try_active)
   lambda = [];           % multipliers matching the rows of R
   Rindex = 0;            % rows Rindex(r)+1:Rindex(r+1) of R belong to block r
   Oindex = [];           % [constrained index, block r] per active constraint
   % Attempt the simultaneous solve once the active count stops decreasing.
   try_active = (activ0<=activ1);
   soln_lag = soln;
   if try_active
      fprintf('A final solution will be attempted after this iteration.\n');
   end
   activ0 = activ1;
   activ1 = 0;

   % --- Concavity: project onto each observation's constraint block ---
   r = 0;
   while r < N
      r = r+1;
      % Skip observations with no constraints (empty Kindex segment). When
      % try_active, record an empty R segment for each skipped block.
      while m(r)==m(r+1)
         if try_active
            Rindex = [Rindex;Rindex(r)];
         end
         r = r+1;
         if r==N+1
            break;
         end
      end
      if r==N+1
         break
      end
      shift0 = shift(:,r);
      % Remove every other block's shift from the data (better than adding
      % shift0 back). sum(...,2) sums across columns even when only one
      % column is selected, where sum(X')' would collapse to a scalar.
      f   = dep - sum(shift(:,[1:r-1,r+1:N+1]),2);
      f   = reshape(f,M,N);
      % Observations h with q(h,:) >= q(r,:) (from Kindex), plus r itself.
      Use = [Kindex(m(r)+1:m(r+1),1);r];
      U   = m(r+1) - m(r) + 1;
      % project returns multipliers, a selector into Use, and the count U
      % (presumably of active constraints; TODO confirm against project.m).
      [lambda_r,Used,U] = project(f(:,Use),p(r,:),r,U);
      if U>0
         % One row per selected constraint: -1 at each Use(Used) column,
         % then column r set to +1. Expanded by p(r,:) via kron.
         R_r = zeros(U,N);
         R_r(:,Use(Used)) = -eye(U);
         R_r(:,r) = ones(U,1);
         shift_r = kron(R_r',p(r,:)')*lambda_r;
         Oindex = [Oindex; [Use(Used),r*ones(U,1)] ];
      else
         R_r = [];
         shift_r = zeros(N*M,1);
      end
      if try_active
         R = [R;kron(R_r,p(r,:))];
         lambda = [lambda;lambda_r];
         Rindex = [Rindex;Rindex(r)+U];
      end
      fit1 = f(:) - shift_r;
      shift(:,r) = shift_r;
      activ1 = activ1 + U;
      same_R = same_R*all((shift0~=0)==(shift_r~=0));
      % Sanity checks: a projection must never lengthen the vector.
      lenf  = f(:)'*f(:);
      lenf1 = fit1'*fit1;
      if ~fle(lenf1,lenf)
          fprintf('Did not shorten in projection\n');
          keyboard
      end
      if ~fle(lenf1,lenf0)
          fprintf('Did not shorten fitted\n');
          keyboard
      end
      lenf0 = lenf1;
      clear Use f_proj;
   end

   % --- Monotonicity (block N+1): clip negative fitted values to zero ---
   r    = N+1;
   fit1 = dep - sum(shift(:,1:N),2);
   %  Find new shift
   Used = fit1<0;
   U    = sum(Used);
   shift(:,N+1) = fit1.*Used;
   %  Project onto new monotonicity constraints
   fit1 = fit1.*(fit1>=0);
   active_mono = U;
   activ1 = activ1 + U;
   if try_active
      if U>0
         R_r = zeros(U,M*N);
         R_r(:,Used) = -eye(U);
         R = [R;R_r];
         lambda = [lambda;-shift(Used,N+1)];
         Oindex = [Oindex; [find(Used),r*ones(U,1)] ];
      end
      Rindex = [Rindex;Rindex(r)+U];
   end

   if same_R==1
      fprintf('Active constraints unchanged over this iteration.\n');
      if soln_lag
         fprintf('Previous solution of active constraints yielded solution.\n');
         break;
      end
   end

   % If the solution is cycling among the same set of active constraints,
   % attempt to solve all active constraints simultaneously.
   if try_active
      fprintf('------------------------------------------------\n');
      fprintf('Iteration: %g -- leny: %g -- SSQ: %g\n', ...
               iter, fit1'*fit1,(dep-fit1)'*(dep-fit1) );
      fprintf('               %g active constraints\n',activ1);
      fprintf('------------------------------------------------\n');
      fprintf('Attempting to solve active constraints...\n');
      [R,lambda,soln,pivot] = active(dep,R,lambda);
      if soln
         fprintf('Solved all active constraints successfully.\n');
      end
      % Rebuild each block's shift from its rows of R and its multipliers.
      r = 0;
      while r < N+1
         r = r+1;
         if Rindex(r)==Rindex(r+1)
            shift(:,r) = zeros(N*M,1);
         else
            shift(:,r) = ...
            R(Rindex(r)+1:Rindex(r+1),:)'*lambda(Rindex(r)+1:Rindex(r+1));
         end
      end
      if ~all(feq(sum(shift,2),R'*lambda))
         fprintf('Error accumulating shift matrix.\n');
         keyboard
      end
      fit2  = dep - sum(shift,2);  % better than "dep - R'*lambda"
      lenf2 = fit2'*fit2;
      if ~fle(lenf2,lenf1)
          fprintf('active failed\n');
          keyboard
      end
      lenf0 = lenf2;
      fit1  = fit2;
      lambda = full(lambda);
      activ1 = sum(lambda>0);
      % A pivot past Rindex(N+1) falls in the monotonicity rows. The semicolon
      % keeps the update from echoing into the iteration report.
      if pivot>Rindex(N+1)
         active_mono = active_mono - 1;
      end
   end

   fprintf('------------------------------------------------\n');
   fprintf('Iteration: %g -- leny: %g -- SSQ: %g\n', ...
           iter, fit1'*fit1,(dep-fit1)'*(dep-fit1) );
   fprintf('               %g active constraints\n',activ1);
   fprintf('               %g concavity\n',activ1-active_mono);
   fprintf('               %g monotonicity\n',active_mono);
   fprintf('------------------------------------------------\n');

end

% Announce convergence only when the loop ended because the fit stopped
% changing, not via the break taken after a successful simultaneous solve
% (soln_lag true). Then save the results; note that y is saved in its
% transposed form.
if all(feq(fit0,fit1)) && ~soln_lag
   fprintf('Fit stationary over last iteration ==> solution found.\n');
end

save('results_cg', 'fit1', 'y', 'p', 'q');
diary('off');