Files
simpeg/code/EldadsCode/tools.py
T

223 lines
6.0 KiB
Python

import numpy;
import cmath;
import math;
def prod(arg):
    """Return the product of the elements in arg.

    arg may be any iterable of numerical values (list, tuple, set, or
    numpy array). The elements are visited by iteration rather than by
    position, so containers that cannot be indexed, such as sets, are
    supported. An empty arg yields 1, the multiplicative identity.
    """
    ret = 1
    for value in arg:
        ret = ret * value
    return ret
def allIndices(dim):
    """Generate every index of an array with the given shape.

    dim is a sequence of dimension sizes, e.g. (2, 3, 4). The result is an
    integer numpy array with one row per index and one column per
    dimension, sorted so that the last dimension varies fastest; for
    example allIndices((2, 2)) gives [[0, 0], [0, 1], [1, 0], [1, 1]].

    A dimension of size 0 gives an array with no rows, and an empty dim
    gives a single row with no columns.
    """
    shape = tuple(int(extent) for extent in dim)
    if not shape:
        return numpy.zeros((1, 0), dtype=int)
    # numpy.indices returns one coordinate grid per dimension; flattening
    # each grid in row-major order and transposing puts one complete index
    # on every row, already in sorted order.
    grids = numpy.indices(shape)
    return grids.reshape(len(shape), -1).T
def find(nda, obj):
    """Return the position of the first element of nda equal to obj.

    nda may be an ndarray, list, or tuple. When no element compares equal
    to obj the result is -1.
    """
    for position, item in enumerate(nda):
        if item == obj:
            return position
    return -1
def notin(n, vector):
    """Return the values of [0, 1, ..., n-1] that do not occur in vector.

    The result is an integer numpy.array in increasing order. Entries of
    vector outside [0, n-1] are ignored, and repeated entries are harmless:
    membership is tested against a set instead of removing items one by one,
    which used to raise ValueError on the second occurrence of a value.
    """
    excluded = set(vector)
    kept = [value for value in range(n) if value not in excluded]
    # dtype is pinned so that an empty result is still an integer array.
    return numpy.array(kept, dtype=int)
def getelts(nda, indices):
    """Collect the elements of nda found at the given indices.

    nda may be an ndarray, list, or tuple. The selected elements are
    returned as a numpy.array, in the same order as indices.
    """
    picked = [nda[position] for position in indices]
    return numpy.array(picked)
def sub2ind(shape, subs):
    """Convert subscripts into linear indices for the given shape.

    The last dimension varies fastest. subs is multiplied (numpy.dot) by a
    column vector of per-dimension weights, so a matrix with one subscript
    per row produces a column of indices.
    """
    extents = list(shape)
    # Build the weights from the last dimension backwards: the last one is 1
    # and each earlier one is the product of all the extents after it.
    weights = [1]
    for extent in reversed(extents[1:]):
        weights.insert(0, weights[0] * extent)
    column = numpy.array(weights).reshape(len(weights), 1)
    return numpy.dot(subs, column)
def ind2sub(shape, ind):
    """Convert a linear index into subscripts for the given shape.

    The last dimension varies fastest. Returns a list with one integer
    subscript per dimension of shape (an empty list for an empty shape).

    The arithmetic uses divmod, i.e. floor division, instead of float
    division followed by math.floor, so the result is exact for indices
    too large to be represented as a float.
    """
    extents = [int(extent) for extent in shape]
    # strides[k] is the number of linear positions covered by one step
    # along dimension k: the product of all the extents after k.
    strides = []
    running = 1
    for extent in reversed(extents):
        strides.append(running)
        running *= extent
    strides.reverse()

    sub = []
    remainder = ind
    for stride in strides:
        quotient, remainder = divmod(remainder, stride)
        sub.append(int(quotient))
    return sub
def tt_dimscehck(dims, N, M = None, exceptdims = False):
    """Check that dims are valid dimensions of an N-dimensional tensor.

    dims        -- sequence of dimension numbers.
    N           -- number of dimensions of the tensor; valid entries of
                   dims lie in [0, N-1].
    M           -- optional number of multiplicands. Must equal either N
                   or the number of dimensions being processed.
    exceptdims  -- if True, dims lists the dimensions to leave out and the
                   check runs on every other dimension in [0, N-1].

    Returns the sorted dimensions when M is None. Otherwise returns a tuple
    (sdims, vidx): sdims is the sorted dimensions, and vidx is, for each
    entry of sdims, its position in the unsorted dims when M equals
    len(dims), or sdims itself when M equals N.

    Raises ValueError if a dimension is outside [0, N-1], if M > N, or if
    M is neither N nor the number of dimensions.
    """
    if exceptdims:
        # Complement of dims within range(N), computed inline so this
        # function does not depend on any other helper.
        excluded = set(dims)
        dims = [d for d in range(N) if d not in excluded]

    for d in dims:
        if d < 0 or d >= N:
            raise ValueError("invalid dimensions specified")

    # number of dimensions in dims
    p = len(dims)

    # sorted() is stable, so equal dimensions keep their original relative
    # order; sidx[i] is the position in dims of the i-th sorted entry.
    sidx = sorted(range(p), key=lambda j: dims[j])
    sdims = [dims[j] for j in sidx]

    if M is None:
        return sdims

    if M > N:
        raise ValueError("Cannot have more multiplicands than dimensions")
    if M != N and M != p:
        raise ValueError("invalid number of multiplicands")

    if M == p:
        vidx = sidx
    else:
        vidx = sdims
    return (sdims, vidx)
def listtimes(list, c):
    """Multiply every element of list by the scalar c.

    Returns a new Python list of the same length; the input is not modified.
    (The parameter is named `list` for backward compatibility, so the
    builtin of that name is shadowed inside this function.)

    NOTE: each element is multiplied by c. The previous expression,
    [list[i]]*c, repeated each element c times instead, which contradicted
    this contract and failed for a non-integer c.
    """
    return [item * c for item in list]
def listdiff(list1, list2):
    """Return the elements of list1 that do not occur in list2.

    Both arguments may be lists, tuples, ranges, or numpy arrays (arrays
    are converted with tolist() first). The result is a new Python list
    that keeps the order, and any duplicates, of list1.
    """
    if isinstance(list1, numpy.ndarray):
        list1 = list1.tolist()
    if isinstance(list2, numpy.ndarray):
        list2 = list2.tolist()
    try:
        # Constant-time membership tests when the elements are hashable.
        excluded = set(list2)
    except TypeError:
        # Unhashable elements (e.g. nested lists): fall back to comparing
        # by equality against a plain list.
        excluded = [item for item in list2]
    return [item for item in list1 if item not in excluded]
def tt_subscheck(subs):
    """Check whether the given array of subscripts is valid. Used for sptensor.

    subs must be a numpy array. It is accepted when it is empty, or when it
    is two-dimensional and every entry is a finite, non-negative,
    integer-valued number.

    Returns True when subs is valid and raises ValueError otherwise.
    """
    if subs.size == 0:
        return True

    isOk = subs.ndim == 2
    if isOk:
        # Whole-array checks replace the element-by-element loop, whose
        # range() bound was a float under true division. Finiteness is
        # tested first so NaN/inf never reach the rounding comparison.
        isOk = bool(
            numpy.isfinite(subs).all()
            and (subs >= 0).all()
            and (subs == numpy.round(subs)).all()
        )

    if not isOk:
        raise ValueError("Subscripts must be a matrix of non-negative integers")
    return True
def tt_valscheck(vals):
    """Check whether the given array of values is valid. Used for sptensor.

    vals is accepted when it is empty, or when it is two-dimensional and
    its first row holds exactly one element (a column array).

    Returns True when vals is valid and raises ValueError otherwise.
    """
    # An empty array is always fine; anything else must look like a column.
    if vals.size != 0 and (vals.ndim != 2 or vals[0].size != 1):
        raise ValueError("values must be a column array")
    return True
def tt_sizecheck(size):
    """Check whether the given size is valid. Used for sptensor.

    size is converted with numpy.array and is accepted when the result is
    one-dimensional and every entry is a finite, strictly positive,
    integer-valued number.

    Returns True when size is valid and raises ValueError otherwise.
    """
    extents = numpy.array(size)
    valid = extents.ndim == 1
    if valid:
        # Every entry is examined (no early exit); one bad entry is enough
        # to reject the whole size.
        rejected = [
            cmath.isnan(entry) or cmath.isinf(entry)
            or entry <= 0 or entry != round(entry)
            for entry in extents
        ]
        valid = not any(rejected)
    if not valid:
        raise ValueError("size must be a row vector of real positive integers")
    return valid