function varargout = getparm( net, expectations, varargin )
%  GETPARM.M
%
%  [ v1, v2, v3, ... ] = getparm( net, expectations, 'v1', 'v2', 'v3', ... )
%    returns variables 'v1', 'v2', 'v3', ... of the python net.
%    All the nodes corresponding to the same variable 'v1' are
%    collected in one field. For example, python nodes with labels
%    'v1(1)',..., 'v1(n)' would be stored in the same output structure v1.
%
%  The argument expectations specifies whether to store only the
%    expected values of the variables (expectations = 1) or all relevant
%    fields of the nodes (expectations=0).
%
%  parms = getparm( net, expectations, 'v1', 'v2', 'v3', ... ) collects
%    the returned variables in the fields of the output argument parms.
%    This struct form is used whenever exactly one output is requested,
%    even if only a single variable name is given.
%
%  An error is raised if a requested variable matches no node label.
%
parm = struct([]);
nvars = length(varargin);

for i = 1:nvars

    varname = varargin{i};
    found = 0;

    % Look for the nodes corresponding to variable varname.
    % NOTE(review): the scan starts at net{2}, so net{1} is never examined;
    % presumably it is not a variable node -- confirm against the net format.
    for j = 2:length(net)

        % nodename - the name of the variable (label text before '(')
        % nodeix   - the index part of the label, '(' included, or ''
        label = net{j}.label;
        ptr = find( label == '(' );
        if isempty( ptr )
            nodename = label;
            nodeix = '';
        else
            % Split at the first parenthesis only.
            nodename = label( 1:ptr(1)-1 );
            nodeix = label( ptr(1):end );
        end

        if strcmp( varname, nodename )
            if expectations
                parm = TakeExpectation( net{j}, nodename, nodeix, parm );
            else
                parm = SetParmFields( net{j}, nodename, nodeix, parm );
            end
            found = 1;
        end
    end
    if ~found
        error( 'Variable %s not found in net.', varname );
    end
end

if nargout == 1 && nvars > 0
    varargout{1} = parm;
else
    % One output per requested name. The bound is the number of names,
    % not nargin-1: nargin also counts net and expectations, so the old
    % bound indexed varargin one past its end.
    for i = 1:nvars
        varargout{i} = getfield( parm, varargin{i} );
    end
end

return


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function parm = TakeExpectation( node, nodename, nodeix, parm )
%  TAKEEXPECTATION  Store one node's expectation value in parm.
%
%    A value is picked from the node according to node.type and assigned
%    to parm.<nodename><index>, where <index> is nodeix after it has been
%    adjusted by AddDimension for the size of the picked value.
%
%    node     - node structure; must have a 'type' field
%    nodename - variable name, used as the field name in parm
%    nodeix   - index text taken from the node label (may be empty)
%    parm     - structure collected so far (may be empty)

% Start from an undefined variable when nothing has been stored yet, so
% that the eval below creates parm (presumably because dot assignment
% into an empty struct array is not accepted).
if isempty(parm), clear parm; end

switch node.type
case { 'Gaussian', 'GaussianV' }
    expt = node.myval.mean;

case { 'GaussNonlin', 'GaussNonlinV' }
    expt = node.myval1.mean;

case { 'RectifiedGaussian', 'RectifiedGaussianV' }
    expt = node.expectations.mean;

case { 'Dirichlet', 'MoGV' }
    expt = node.expts.mean;

case { 'DiscreteDirichletV', 'DiscreteV' }
    % One column per entry of node.myval.
    for k = 1:length( node.myval )
        expt(:,k) = node.myval{k}.val';
    end

otherwise
    % Unlisted type: try the known containers in a fixed order.
    if isfield( node, 'expts' )
        expt = node.expts.mean;
    elseif isfield( node, 'expectations' )
        expt = node.expectations.mean;
    elseif isfield( node, 'myval' ) && isfield( node.myval, 'mean' )
        expt = node.myval.mean;
    elseif isfield( node, 'myval' ) && isfield( node.myval, 'val' )
        expt = node.myval.val;
    else
        expt = [];
    end
end

% Build the assignment target as text; the index part depends on the
% size of the value being stored.
target = [ 'parm.' nodename AddDimension( nodeix, size(expt) ) ];
eval( [ target ' = expt;' ] );

return


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function parm = SetParmFields( node, nodename, nodeix, parm )
%  SETPARMFIELDS  Copy the type-specific fields of one node into parm.
%
%    Always stores node.type in parm.<nodename>.type, then copies further
%    fields depending on node.type. Types not listed in the switch get
%    only the 'type' field.
%
%    node     - node structure; must have a 'type' field
%    nodename - variable name, used as the field name in parm
%    nodeix   - index text taken from the node label (may be empty)
%    parm     - structure collected so far (may be empty)
%
%    NOTE(review): the discrete branch does not use nodeix, so every node
%    with the same nodename writes to the same parm.<nodename>.myval
%    columns -- confirm this is intended.
%    NOTE(review): the GaussNonlin branch passes node.myval under the
%    label 'myval1'; SaveStructFields ignores that label, while
%    TakeExpectation reads node.myval1 for the same types -- confirm
%    which field is meant.

% Start from an undefined variable when nothing has been stored yet, so
% that the first eval below creates parm.
if isempty(parm), clear parm; end

% Common prefix of every assignment target built below.
target = [ 'parm.' nodename ];
eval( [ target '.type = node.type;' ] );

switch node.type
case { 'Gaussian', 'GaussianV' }
    parm = SaveStructFields( 'myval', node.myval, nodename, nodeix, parm );

case { 'GaussNonlin', 'GaussNonlinV' }
    parm = SaveStructFields( 'myval1', node.myval, nodename, nodeix, parm );

case { 'DiscreteDirichletV', 'DiscreteV' }
    % Entry k of node.myval becomes column k of parm.<nodename>.myval.
    for k = 1:length( node.myval )
        col = int2str( k );
        eval( [ target '.myval(:,' col ') = node.myval{' col '}.val'';' ] );
    end

case 'Dirichlet'
    fldix = AddDimension( nodeix, size(node.myval) );
    eval( [ target '.myval' fldix ' = node.myval;' ] );
    parm = SaveStruct( 'expts', node.expts, nodename, nodeix, parm );

case 'MoGV'
    parm = SaveStruct( 'expts', node.expts, nodename, nodeix, parm );

    % myval is stored with cell indexing: '(..)' becomes '{..}'.
    cellix = strrep( strrep( nodeix, '(', '{' ), ')', '}' );
    eval( [ target '.myval' cellix ' = node.myval;' ] );
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function parm = SaveStruct( strname, strvalue, nodename, nodeix, parm )
%  SAVESTRUCT  Copy the non-empty fields of strvalue into
%    parm.<nodename>.<strname>.<field><index>.
%
%    strname  - name of the sub-structure to create under parm.<nodename>
%    strvalue - structure whose fields are copied
%    nodename - variable name, used as the field name in parm
%    nodeix   - index text taken from the node label (may be empty)
%    parm     - structure collected so far (may be empty)

% Start from an undefined variable when nothing has been stored yet, so
% that the first eval below creates parm.
if isempty(parm), clear parm; end

prefix = [ 'parm.' nodename '.' strname '.' ];
names = fieldnames( strvalue );

for k = 1:length( names )
    fldname = names{k};
    fldvalue = getfield( strvalue, fldname );

    % Empty fields are not copied.
    if isempty( fldvalue )
        continue
    end

    % The index text is adapted to the size of each field separately.
    fldix = AddDimension( nodeix, size(fldvalue) );
    eval( [ prefix fldname fldix ' = strvalue.' fldname ';' ] );
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function parm = SaveStructFields( strname, strvalue, nodename, nodeix, parm )
%  SAVESTRUCTFIELDS  Copy the non-empty fields of strvalue directly into
%    parm.<nodename>.<field><index>.
%
%    strname  - not used in this function
%    strvalue - structure whose fields are copied
%    nodename - variable name, used as the field name in parm
%    nodeix   - index text taken from the node label (may be empty)
%    parm     - structure collected so far (may be empty)

% Start from an undefined variable when nothing has been stored yet, so
% that the first eval below creates parm.
if isempty(parm), clear parm; end

prefix = [ 'parm.' nodename '.' ];
names = fieldnames( strvalue );

for k = 1:length( names )
    fldname = names{k};
    fldvalue = getfield( strvalue, fldname );

    % Empty fields are not copied.
    if isempty( fldvalue )
        continue
    end

    % The index text is adapted to the size of each field separately.
    fldix = AddDimension( nodeix, size(fldvalue) );
    eval( [ prefix fldname fldix ' = strvalue.' fldname ';' ] );
end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  Add another dimension for the vector parameter
function nodeix = AddDimension( nodeix, size_value )
%  ADDDIMENSION  Adapt an index string to the size of the value stored.
%
%    nodeix     - index text such as '(3)'; may be empty
%    size_value - size vector of the value to be stored
%
%    Returned unchanged when the value has at most one column or when
%    nodeix is empty. For a single-row value with several columns the
%    closing parenthesis is replaced by ',:)', e.g. '(3)' -> '(3,:)'.
%    For any other value the parentheses become braces,
%    e.g. '(3)' -> '{3}'.

if size_value(2) <= 1 || isempty(nodeix)
    return
end

if size_value(1) == 1
    nodeix = [ nodeix(1:end-1) ',:)' ];
else
    nodeix = strrep( strrep( nodeix, '(', '{' ), ')', '}' );
end
