From 8f39bb7f550594614da52481620b54fdf86b211a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 27 Feb 2026 14:34:34 -0500 Subject: [PATCH 01/38] Initial integration of NUG into abinitio class. --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_nug.py | 789 ++++++++++++++++++++++++++ 2 files changed, 790 insertions(+) create mode 100644 src/aspire/abinitio/commonline_nug.py diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index d8e2cfddcd..4b6248c4c7 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -6,6 +6,7 @@ ) from .commonline_base import Orient3D from .commonline_matrix import CLOrient3D +from .commonline_nug import CommonlineNUG from .commonline_sdp import CommonlineSDP from .commonline_lud import CommonlineLUD from .commonline_irls import CommonlineIRLS diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py new file mode 100644 index 0000000000..eb092e8604 --- /dev/null +++ b/src/aspire/abinitio/commonline_nug.py @@ -0,0 +1,789 @@ +import logging +import time + +import numpy as np +import cupy as cp +from scipy.special import sph_harm, binom, jacobi, factorial +from scipy.io import loadmat +from scipy.spatial.transform import Rotation as spr + +from aspire.abinitio import CLOrient3D +from aspire.nufft import nufft +from aspire.numeric import fft, xp +from aspire.operators import wemd_embed + +logger = logging.getLogger(__name__) + + +class CommonlineNUG(CLOrient3D): + """ + Class to estimate 3D orientations using non-uqique games. + """ + + def __init__( + self, + src, + symmetry=None, + n_rad=None, + n_theta=360, + max_shift=0.15, + shift_step=1, + mask=True, + Lmax=12, + loss='l1', + T=36, + max_iter=501, + rho=0.05, + ratio=1, + factor=1.0, + mult=1.5, + Ngrid=16317, + Nstep_yI=10, + **kwargs, + ): + """ + Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. + + :param src: The source object of 2D denoised or class-averaged images with metadata + :param symmetry: A string, ie. 'C3', indicating the symmetry type. + :param n_rad: The number of points in the radial direction + :param n_theta: The number of points in the theta direction. Default = 360. + """ + + super().__init__( + src, + n_rad=n_rad, + n_theta=n_theta, + max_shift=max_shift, + shift_step=shift_step, + mask=mask, + **kwargs, + ) + + self.Lmax = Lmax + self.loss = loss + self.T = T + self.max_iter = max_iter + self.rho = rho + self.ratio = ratio + self.factor = factor + self.mult = mult + self.Ngrid = Ngrid + self.Nstep_yI = Nstep_yI + self.sym = symmetry + + def estimate_rotations(self): + sym_euler,S=self.Symmetry_Euler(self.sym) + imgs = self.src.images[:] + C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) + X_est = self.admm_sym_J( + self.sym, + C, + self.Lmax, + self.n_img, + self.Ngrid, + self.max_iter, + self.rho, + self.ratio, + self.factor, + self.mult, + self.Nstep_yI, + ) + + R_est,Euler_est = self.euler_est(X_est[0],X_est[S-1],self.sym,self.n_img) + self.rotations = R_est + + return R_est + + ####################### + # Compute Coeffs Step # + ####################### + + def compute_coeff(self, Img, loss, Lmax, T): + # compute the coefficient matrix + N,L,_=Img.shape; n_theta=360 + angular_sampling = np.arange(0, 360, 1) + line_proj=np.zeros((L,n_theta,N)) + Img_pft=np.zeros((L,n_theta,N),dtype=complex) + Img = Img.asnumpy() + for n in range(N): + line_proj[:,:,n],Img_pft[:,:,n]=self.fast_radon_transform(Img[n], angular_sampling) + + dim_wave=len(wemd_embed(line_proj[:,0,0])) + WE=np.zeros((dim_wave,n_theta,N)) + for i in range(N): + for theta in range(n_theta): + WE[:,theta,i]=wemd_embed(line_proj[:,theta,i]) + + + def fij(alpha,gamma,i,j,loss): + if loss=='l1': + Ii_hat=Img_pft[:,:,i]; Ij_hat=Img_pft[:,:,j] + idxi=np.round((alpha-np.pi/2)*n_theta/2/np.pi)% n_theta + idxj=np.round((-gamma-np.pi/2)*n_theta/2/np.pi)% n_theta + + Si=Ii_hat[:,int(idxi)]; Sj=Ij_hat[:,int(idxj)] + return np.linalg.norm(Si-Sj,1) + + if loss=='wemd': + idxi=np.round((alpha-np.pi/2)*n_theta/2/np.pi)% n_theta + idxj=np.round((-gamma-np.pi/2)*n_theta/2/np.pi)% n_theta + + Si=WE[:,int(idxi),i]; Sj=WE[:,int(idxj),j] + return np.linalg.norm(Si-Sj,1) + + alpha_grid=np.arange(2*T)*np.pi/T + beta_grid=(2*np.arange(2*T)+1)*np.pi/4/T + gamma_grid=np.arange(2*T)*np.pi/T + + bT=np.zeros(2*T); + for l in range(2*T): + ss=0 + for m in range(T): + ss=ss+np.sin(beta_grid[l]*(2*m+1))/(2*m+1) + bT[l]=2/T*np.sin(beta_grid[l])*ss + + BTK=[]; + for k in range(1,Lmax+1): + dk=2*k+1 + btk=np.zeros((dk,dk)) + for l in range(2*T): btk=btk+bT[l]*self.Wd(k,beta_grid[l]) + BTK.append(btk.T) + + + def fijhat_k(k,F): + dk=2*k+1; + + exp_alpha_grid=np.zeros((2*T,dk),dtype=complex) + for m in range(-k,k+1): + exp_alpha_grid[:,m+k]=np.exp(1j*m*alpha_grid) + + exp_gamma_grid=np.zeros((2*T,dk),dtype=complex) + for m in range(-k,k+1): + exp_gamma_grid[:,m+k]=np.exp(1j*m*gamma_grid) + + S=(exp_alpha_grid.T@F@exp_gamma_grid).T + fhat=BTK[k-1]*S/4/T**2 + return fhat + + C=[] + for k in range(1,Lmax+1): + dk=2*k+1; C.append(np.zeros((N*dk,N*dk),dtype=complex)) + for i in range(N): + for j in range(i+1,N): + Fij=np.zeros((2*T,2*T)) + for j1 in range(2*T): + for j2 in range(2*T): + Fij[j1,j2]=fij(alpha_grid[j1],gamma_grid[j2],i,j,loss) + for k in range(1,Lmax+1): + dk=2*k+1; C[k-1][j*dk:(j+1)*dk,i*dk:(i+1)*dk]=fijhat_k(k,Fij)#*dk + for k in range(1,Lmax+1): + C[k-1]=C[k-1]+C[k-1].conj().T + + for i in range(N): + Fii=np.zeros((2*T,2*T)) + for j1 in range(2*T): + for j2 in range(2*T): + Fii[j1,j2]=fij(alpha_grid[j1],gamma_grid[j2],i,i,loss) + for k in range(1,Lmax+1): + dk=2*k+1 + C[k-1][i*dk:(i+1)*dk,i*dk:(i+1)*dk]=fijhat_k(k,Fii)#*dk + + + for k in range(1,Lmax+1): + [T,Tinv] = self.complex2real(k) + C[k-1] = np.real(np.kron(np.eye(N),Tinv) @ C[k-1]@np.kron(np.eye(N),T)) + C[k-1]=np.round(C[k-1],10) + + return C + + @staticmethod + def fast_radon_transform(array, angles, use_ramp=False): + + angles = np.array(angles).flatten() + img_size = array.shape[1] + rads = angles / 180 * np.pi + y_idx = np.arange(-img_size / 2, img_size / 2) / img_size * 2 + x_theta = y_idx[:, np.newaxis] * np.sin(rads)[np.newaxis, :] + y_theta = y_idx[:, np.newaxis] * np.cos(rads)[np.newaxis, :] + + pts = np.pi * np.vstack( + [ + x_theta.flatten(), + y_theta.flatten(), + ] + ) + pts = pts.astype(array.dtype) + + #array = array.astype(np.float32) + lines_f = nufft(array, pts).reshape((img_size, -1)) + + if img_size % 2 == 0: + lines_f[0, :] = 0 + + if use_ramp: + freqs = np.abs(np.pi * y_idx) + lines_f *= freqs[:, np.newaxis] + + projections = np.real(xp.asnumpy(fft.centered_ifft(xp.asarray(lines_f), axis=0))) + + return projections, lines_f + + + @staticmethod + def Wd(J,beta): + # compute Wigner small d matrix + d=np.zeros((2*J+1,2*J+1)); + + for m in range(-J,J+1): + for n in range(-J,J+1): + smin=max(0,m-n); smax=min(J+m,J-n); + for s in range(smin,smax+1): + mul=np.sqrt(factorial(J+m))/factorial(J+m-s)*np.sqrt(factorial(J+n))/factorial(s)*np.sqrt(factorial(J-m))/factorial(n-m+s)\ + *np.sqrt(factorial(J-n))/factorial(J-n-s) + d[n+J,m+J]+=mul*(-1)**(n-m+s)*(np.cos(beta/2))**(2*J+m-n-2*s)*(np.sin(beta/2))**(n-m+2*s); + + return d + + @staticmethod + def complex2real(ell): + # compute transformation matrices that convert complex representations to real ones + diml=2*ell+1 + Tinv=np.zeros((diml,diml),dtype=complex) + for i in range(diml): + if iell: + Tinv[i,i]=(-1)**(i-ell)/np.sqrt(2); Tinv[i,diml-1-i]=1/np.sqrt(2) + + T=np.zeros((diml,diml),dtype=complex) + for i in range(diml): + if iell: + T[i,i]=(-1)**(i-ell)/np.sqrt(2); T[i,diml-1-i]=1j*(-1)**(i-ell)/np.sqrt(2) + return T,Tinv + + + ############# + # ADMM Step # + ############# + + def admm_sym_J( + self, + sym, + C, + Lmax, + N, + Ngrid, + max_iter, + rho, + ratio, + factor, + mult=1, + Nstep_yI=20, + verbose=True, + GPU=True, + ): + # admm for symmetric case + if GPU: + import cupy as cp + else: + import numpy as cp + + C0,C1,normC,AEq,bEq,AEqAEqtinv,AI_mat_diag,AI_mat_offdiag,bI,Lambda,d0,d1,D0,D1,idx_diag,idx_offdiag,IDX_upper,IDX_lower,X0,X1,Xq,S0,S1,Sq=self.ADMM_preprocessing(C,Lmax,N,Ngrid,GPU) + + # S=int(sym[1]); sym_euler=Symmetry_Euler(sym) + rank_Ak,_=self.compute_rank(sym,Lmax); print(rank_Ak) + + # rank_Ak=cp.zeros(Lmax) + # for ell in range(Lmax): rank_Ak[ell]=np.linalg.matrix_rank(Ak(ell+1,sym_euler)) + + AE=[]; AEAETinv=[] + for k in range(1,Lmax+1): + s0=k**2; s1=(k+1)**2 + AEk=cp.zeros((1+s0+s1,2*(s0+s1))) + AEk[0,:s0]=cp.eye(k).T.reshape(-1); AEk[0,s0:s0+s1]=cp.eye(k+1).T.reshape(-1) + for count in range(1,1+s0): + AEk[count,count-1]=1; AEk[count,count-1+s0+s1]=1 + for count in range(1+s0,1+s0+s1): + AEk[count,count-1]=1; AEk[count,count-1+s1+s0]=1 + AE.append(AEk) + AEAETinv.append(np.linalg.pinv(AEk@AEk.T)) + bE=cp.zeros((Lmax+D0+D1)) + for k in range(Lmax): + bE[k+d0[k]+d1[k]:]=rank_Ak[k] + bE[k+1+d0[k]+d1[k]:k+1+d0[k+1]+d1[k]]=cp.eye(k+1).T.reshape(-1) + bE[k+1+d0[k+1]+d1[k]:k+1+d0[k+1]+d1[k+1]]=cp.eye(k+2).T.reshape(-1) + bE=np.repeat(bE[:,np.newaxis],N,axis=1) + P=[]; + for k in range(1,Lmax+1): + dk=2*k+1; Pk=cp.eye(dk) + for m in range(k): + for l in range(k-m): + Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] + P.append(Pk) + + + def fun_AE(X0,X1,Xd0,Xd1,Xq): + z=cp.zeros((Lmax+D0+D1,N)) + for k in range(Lmax): + z[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]=AE[k]@ cp.concatenate((X0[d0[k]:d0[k+1],idx_diag],X1[d1[k]:d1[k+1],idx_diag],Xd0[d0[k]:d0[k+1]],Xd1[d1[k]:d1[k+1]]),axis=0) + return z, AEq@cp.concatenate((Xq,X0[:1,idx_offdiag],X1[:4,idx_offdiag]),axis=0) + + def fun_AET(yE,yEq): + Z0=cp.zeros((D0,N*(N+1)//2)); Z1=cp.zeros((D1,N*(N+1)//2)); + Zd0=cp.zeros((D0,N)); Zd1=cp.zeros((D1,N)) + for k in range(Lmax): + s0=(k+1)**2; s1=(k+2)**2 + Ztmp=AE[k].T@yE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]] + Z0[d0[k]:d0[k+1],idx_diag]=Ztmp[:s0] + Z1[d1[k]:d1[k+1],idx_diag]=Ztmp[s0:s0+s1] + Zd0[d0[k]:d0[k+1]]=Ztmp[s0+s1:2*s0+s1] + Zd1[d1[k]:d1[k+1]]=Ztmp[2*s0+s1:2*s0+2*s1] + Zq=AEq.T@yEq + Z0[:1,idx_offdiag]=Zq[16:17]; Z1[:4,idx_offdiag]=Zq[17:] + return Z0,Z1,Zd0,Zd1,Zq[:16] + + def fun_AI(X0,X1): + z=cp.zeros((Ngrid,N*(N+1)//2)); tmp=cp.concatenate((X0,X1),axis=0); + z[:,idx_diag]=AI_mat_diag@tmp[:,idx_diag]; z[:,idx_offdiag]=AI_mat_offdiag@tmp[:,idx_offdiag] + # return AI_mat@cp.concatenate((X0,X1),axis=0) + return z + + def fun_AIT(yI): + # Z=AI_mat.T@yI + # return Z[:D0,:], Z[D0:,:] + Z=cp.zeros((D0+D1,N*(N+1)//2)); Z[:,idx_diag]=AI_mat_diag.T@yI[:,idx_diag]; Z[:,idx_offdiag]=AI_mat_offdiag.T@yI[:,idx_offdiag]; + return Z[:D0,:], Z[D0:,:] + + def update_S(C0,C1,yE,yEq,yI,X0,X1,Xd0,Xd1,Xq,rho,Lmax,N): + Z0,Z1,Zd0,Zd1,Zq=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) + S0=C0-Z0-AIT_yI0-X0/rho; S1=C1-Z1-AIT_yI1-X1/rho; + Sd0=-Zd0-Xd0/rho; Sd1=-Zd1-Xd1/rho; Sq=-Zq-Xq/rho + tic0=time.perf_counter(); + for k in range(1,Lmax+1): + tmp=self.mat_block(S0[d0[k-1]:d0[k],:],N,k,IDX_upper,IDX_lower,idx_offdiag) + tmp=self.psd_projection(tmp) + S0[d0[k-1]:d0[k],:]=self.vec_block(tmp,N,k,IDX_upper) + + tmp=self.mat_block(S1[d1[k-1]:d1[k],:],N,k+1,IDX_upper,IDX_lower,idx_offdiag) + tmp=self.psd_projection(tmp) + S1[d1[k-1]:d1[k],:]=self.vec_block(tmp,N,k+1,IDX_upper) + toc0=time.perf_counter(); Time[0]+=toc0-tic0 + tic1=time.perf_counter(); + for k in range(1,Lmax+1): + for n in range(N): + tmp=self.transform_back_block(Sd0[d0[k-1]:d0[k],n],Sd1[d1[k-1]:d1[k],n],k,P[k-1]) + tmp=self.psd_projection(tmp) + Sd0[d0[k-1]:d0[k],n],Sd1[d1[k-1]:d1[k],n]=self.transform_block(tmp,k,P[k-1]) + toc1=time.perf_counter(); Time[1]+=toc1-tic1 + tic2=time.perf_counter(); + for count in range(N*(N-1)//2): + tmp=self.psd_projection(Sq[:,count].reshape(4,4).T) + Sq[:,count]=tmp.T.reshape(16) + toc2=time.perf_counter(); Time[2]+=toc2-tic2 + return S0,S1,Sd0,Sd1,Sq + + + def update_yE(C0,C1,X0,X1,Xd0,Xd1,Xq,S0,S1,Sd0,Sd1,Sq,yI,rho): + AIT_yI0,AIT_yI1=fun_AIT(yI) + z,zq=fun_AE(-X0/rho+C0-S0-AIT_yI0,-X1/rho+C1-S1-AIT_yI1,-Xd0/rho-Sd0,-Xd1/rho-Sd1,-Xq/rho-Sq) + yE=cp.zeros((Lmax+D0+D1,N)) + for k in range(Lmax): + yE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]=AEAETinv[k]@(bE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]/rho+z[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]) + yEq=AEqAEqtinv@(bEq/rho+zq) + return yE,yEq + + + def update_yI(C0,C1,X0,X1,S0,S1,yE,yEq,yI,rho,Lambda): + Z0,Z1,_,_,_=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) + tmp=fun_AI(-X0/rho+C0-S0-Z0-AIT_yI0,-X1/rho+C1-S1-Z1-AIT_yI1) + yI=yI+bI/rho/Lambda+tmp/Lambda + yI=np.maximum(yI,0) + return yI + + def update_X(C0,C1,X0,X1,Xd0,Xd1,Xq,yE,yEq,yI,S0,S1,Sd0,Sd1,Sq,rho): + Z0,Z1,Zd0,Zd1,Zq=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) + tmp=S0+Z0+AIT_yI0-C0; X0=X0+mult*rho*tmp; resX0=np.linalg.norm(tmp) + tmp=S1+Z1+AIT_yI1-C1; X1=X1+mult*rho*tmp; resX1=np.linalg.norm(tmp) + tmp=Sd0+Zd0; Xd0=Xd0+mult*rho*tmp; resXd0=np.linalg.norm(tmp) + tmp=Sd1+Zd1; Xd1=Xd1+mult*rho*tmp; resXd1=np.linalg.norm(tmp) + tmp=Sq+Zq; Xq=Xq+mult*rho*tmp; resXq=np.linalg.norm(tmp) + return X0,X1,Xd0,Xd1,Xq,np.sqrt(resX0**2+resX1**2+resXd0**2+resXd1**2+resXq**2) + + def update_rho(X0,X1,Xd0,Xd1,Xq,bE,bEq,bI,res_X,rho,factor,normC): + z,zq=fun_AE(X0,X1,Xd0,Xd1,Xq) + res_eq=np.linalg.norm(z-bE)/(1+np.linalg.norm(bE))+np.linalg.norm(zq-bEq)/(1+np.linalg.norm(bEq)); + res_inq=np.linalg.norm(np.maximum(bI-fun_AI(X0,X1),0))/(1+abs(bI)*np.sqrt(Ngrid*N**2)); + p_resnorm=res_eq+res_inq + d_resnorm=res_X/(1+normC) + if d_resnorm>ratio*p_resnorm: rho=rho*factor + if d_resnorm=i: IDX_upper.append(i*N+j); + for j in range(N): + for i in range(N): + if j500: + Lambda=cp.linalg.norm(z); z=z/cp.linalg.norm(z); z=AI@(AI.T@z) + Lambda+=2000; print('Largest eigenvalue of AIAIT is approximately %1.2f'%Lambda) + # Lambda=cp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) + return Lambda + + def compute_rank(self, sym,Lmax): + sym_euler,S=self.Symmetry_Euler(sym); rk=cp.zeros(Lmax); A=[] + for k in range(1,Lmax+1): + dk=2*k+1; Ak=np.zeros((dk,dk),dtype=np.complex128) + for s in range(S): Ak+=self.WD(k,sym_euler[s]) + Ak=np.round(Ak/S,6); A.append(Ak); rk[k-1]=np.linalg.matrix_rank(Ak) + return rk,A + + @staticmethod + def Symmetry_Euler(sym): + if sym[0]=='C': + order=int(sym[1:]); sym_euler=np.zeros((order,3)) + for i in range(order): sym_euler[i]=spr.from_euler('zyx',[2*np.pi/order*i,0,0]).as_euler('zyz') + + if sym[0]=='D': + order=int(sym[1:]); sym_euler=np.zeros((2*order,3)) + for i in range(order): + sym_euler[i]=spr.from_euler('zyx',[2*np.pi/order*i,0,0]).as_euler('zyz') + sym_euler[i+order]=spr.from_euler('zyx',[2*np.pi/order*i,0,np.pi]).as_euler('zyz') + + if sym=='T12': + sym_euler=np.zeros((12,3)) + sym_euler[1]=spr.from_rotvec(2*np.pi/3*np.array([0,0,1])).as_euler('zyz') + sym_euler[2]=spr.from_rotvec(4*np.pi/3*np.array([0,0,1])).as_euler('zyz') + + + return sym_euler,sym_euler.shape[0] + + def WD(self, J,euler): + # compute Wigner D matrix + alpha=euler[0]; beta=euler[1]; gamma=euler[2]; d=self.Wd(J,beta); + D=np.diag(np.exp(-1j*alpha*np.arange(-J,J+1)))@ d @np.diag(np.exp(-1j*gamma*np.arange(-J,J+1))) + return D + + @staticmethod + def Wd(J,beta): + # compute Wigner small d matrix + d=np.zeros((2*J+1,2*J+1)); + + for m in range(-J,J+1): + for n in range(-J,J+1): + smin=max(0,m-n); smax=min(J+m,J-n); + for s in range(smin,smax+1): + mul=np.sqrt(factorial(J+m))/factorial(J+m-s)*np.sqrt(factorial(J+n))/factorial(s)*np.sqrt(factorial(J-m))/factorial(n-m+s)\ + *np.sqrt(factorial(J-n))/factorial(J-n-s) + d[n+J,m+J]+=mul*(-1)**(n-m+s)*(np.cos(beta/2))**(2*J+m-n-2*s)*(np.sin(beta/2))**(n-m+2*s); + + return d + + @staticmethod + def mat_block(vecA,N,sz,IDX_upper,IDX_lower,idx_offdiag): + tmp=vecA.T.reshape(N*(N+1)//2,sz,sz).transpose(0,2,1) + AA=cp.zeros((N**2,sz,sz)); AA[IDX_upper]=tmp; AA[IDX_lower]=tmp[idx_offdiag].transpose(0,2,1) + return (AA.reshape(N,N,sz,sz).transpose(0,2,1,3)).reshape(N*sz,N*sz) + + @staticmethod + def psd_projection(B): + # compute the PSD part of a symmstric matrix + evals,evecs=cp.linalg.eigh((B+B.T)/2); evals=cp.maximum(evals,0) + return (evecs*evals) @ evecs.T + + @staticmethod + def transform_block(A,k,Pk=None): + if Pk is None: + dk=2*k+1; Pk=cp.eye(dk) + for m in range(k): + for l in range(k-m): + Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] + AT=Pk@A@Pk.T + return AT[:k,:k].T.reshape(-1), AT[k:,k:].T.reshape(-1) + + @staticmethod + def transform_back_block(A0,A1,k,Pk=None): + dk=2*k+1; A=cp.zeros((dk,dk)); A[:k,:k]=A0.reshape(k,k).T; A[k:,k:]=A1.reshape(k+1,k+1).T + if Pk is None: + Pk=cp.eye(dk) + for m in range(k): + for l in range(k-m): + Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] + return Pk.T@A@Pk From c78aa9d1e14c7e11132a0762a0af55172c99a2d0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 27 Feb 2026 14:37:49 -0500 Subject: [PATCH 02/38] isort, black --- src/aspire/abinitio/commonline_nug.py | 1441 ++++++++++++++++--------- 1 file changed, 939 insertions(+), 502 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index eb092e8604..97a7420d64 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -1,11 +1,11 @@ import logging import time -import numpy as np import cupy as cp -from scipy.special import sph_harm, binom, jacobi, factorial +import numpy as np from scipy.io import loadmat from scipy.spatial.transform import Rotation as spr +from scipy.special import binom, factorial, jacobi, sph_harm from aspire.abinitio import CLOrient3D from aspire.nufft import nufft @@ -27,10 +27,10 @@ def __init__( n_rad=None, n_theta=360, max_shift=0.15, - shift_step=1, + shift_step=1, mask=True, Lmax=12, - loss='l1', + loss="l1", T=36, max_iter=501, rho=0.05, @@ -71,9 +71,9 @@ def __init__( self.Ngrid = Ngrid self.Nstep_yI = Nstep_yI self.sym = symmetry - + def estimate_rotations(self): - sym_euler,S=self.Symmetry_Euler(self.sym) + sym_euler, S = self.Symmetry_Euler(self.sym) imgs = self.src.images[:] C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) X_est = self.admm_sym_J( @@ -90,7 +90,7 @@ def estimate_rotations(self): self.Nstep_yI, ) - R_est,Euler_est = self.euler_est(X_est[0],X_est[S-1],self.sym,self.n_img) + R_est, Euler_est = self.euler_est(X_est[0], X_est[S - 1], self.sym, self.n_img) self.rotations = R_est return R_est @@ -98,104 +98,116 @@ def estimate_rotations(self): ####################### # Compute Coeffs Step # ####################### - + def compute_coeff(self, Img, loss, Lmax, T): - # compute the coefficient matrix - N,L,_=Img.shape; n_theta=360 - angular_sampling = np.arange(0, 360, 1) - line_proj=np.zeros((L,n_theta,N)) - Img_pft=np.zeros((L,n_theta,N),dtype=complex) + # compute the coefficient matrix + N, L, _ = Img.shape + n_theta = 360 + angular_sampling = np.arange(0, 360, 1) + line_proj = np.zeros((L, n_theta, N)) + Img_pft = np.zeros((L, n_theta, N), dtype=complex) Img = Img.asnumpy() for n in range(N): - line_proj[:,:,n],Img_pft[:,:,n]=self.fast_radon_transform(Img[n], angular_sampling) + line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( + Img[n], angular_sampling + ) - dim_wave=len(wemd_embed(line_proj[:,0,0])) - WE=np.zeros((dim_wave,n_theta,N)) + dim_wave = len(wemd_embed(line_proj[:, 0, 0])) + WE = np.zeros((dim_wave, n_theta, N)) for i in range(N): for theta in range(n_theta): - WE[:,theta,i]=wemd_embed(line_proj[:,theta,i]) - - - def fij(alpha,gamma,i,j,loss): - if loss=='l1': - Ii_hat=Img_pft[:,:,i]; Ij_hat=Img_pft[:,:,j] - idxi=np.round((alpha-np.pi/2)*n_theta/2/np.pi)% n_theta - idxj=np.round((-gamma-np.pi/2)*n_theta/2/np.pi)% n_theta - - Si=Ii_hat[:,int(idxi)]; Sj=Ij_hat[:,int(idxj)] - return np.linalg.norm(Si-Sj,1) - - if loss=='wemd': - idxi=np.round((alpha-np.pi/2)*n_theta/2/np.pi)% n_theta - idxj=np.round((-gamma-np.pi/2)*n_theta/2/np.pi)% n_theta - - Si=WE[:,int(idxi),i]; Sj=WE[:,int(idxj),j] - return np.linalg.norm(Si-Sj,1) - - alpha_grid=np.arange(2*T)*np.pi/T - beta_grid=(2*np.arange(2*T)+1)*np.pi/4/T - gamma_grid=np.arange(2*T)*np.pi/T - - bT=np.zeros(2*T); - for l in range(2*T): - ss=0 + WE[:, theta, i] = wemd_embed(line_proj[:, theta, i]) + + def fij(alpha, gamma, i, j, loss): + if loss == "l1": + Ii_hat = Img_pft[:, :, i] + Ij_hat = Img_pft[:, :, j] + idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + + Si = Ii_hat[:, int(idxi)] + Sj = Ij_hat[:, int(idxj)] + return np.linalg.norm(Si - Sj, 1) + + if loss == "wemd": + idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + + Si = WE[:, int(idxi), i] + Sj = WE[:, int(idxj), j] + return np.linalg.norm(Si - Sj, 1) + + alpha_grid = np.arange(2 * T) * np.pi / T + beta_grid = (2 * np.arange(2 * T) + 1) * np.pi / 4 / T + gamma_grid = np.arange(2 * T) * np.pi / T + + bT = np.zeros(2 * T) + for l in range(2 * T): + ss = 0 for m in range(T): - ss=ss+np.sin(beta_grid[l]*(2*m+1))/(2*m+1) - bT[l]=2/T*np.sin(beta_grid[l])*ss - - BTK=[]; - for k in range(1,Lmax+1): - dk=2*k+1 - btk=np.zeros((dk,dk)) - for l in range(2*T): btk=btk+bT[l]*self.Wd(k,beta_grid[l]) + ss = ss + np.sin(beta_grid[l] * (2 * m + 1)) / (2 * m + 1) + bT[l] = 2 / T * np.sin(beta_grid[l]) * ss + + BTK = [] + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + btk = np.zeros((dk, dk)) + for l in range(2 * T): + btk = btk + bT[l] * self.Wd(k, beta_grid[l]) BTK.append(btk.T) + def fijhat_k(k, F): + dk = 2 * k + 1 - def fijhat_k(k,F): - dk=2*k+1; - - exp_alpha_grid=np.zeros((2*T,dk),dtype=complex) - for m in range(-k,k+1): - exp_alpha_grid[:,m+k]=np.exp(1j*m*alpha_grid) + exp_alpha_grid = np.zeros((2 * T, dk), dtype=complex) + for m in range(-k, k + 1): + exp_alpha_grid[:, m + k] = np.exp(1j * m * alpha_grid) - exp_gamma_grid=np.zeros((2*T,dk),dtype=complex) - for m in range(-k,k+1): - exp_gamma_grid[:,m+k]=np.exp(1j*m*gamma_grid) + exp_gamma_grid = np.zeros((2 * T, dk), dtype=complex) + for m in range(-k, k + 1): + exp_gamma_grid[:, m + k] = np.exp(1j * m * gamma_grid) - S=(exp_alpha_grid.T@F@exp_gamma_grid).T - fhat=BTK[k-1]*S/4/T**2 + S = (exp_alpha_grid.T @ F @ exp_gamma_grid).T + fhat = BTK[k - 1] * S / 4 / T**2 return fhat - C=[] - for k in range(1,Lmax+1): - dk=2*k+1; C.append(np.zeros((N*dk,N*dk),dtype=complex)) + C = [] + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + C.append(np.zeros((N * dk, N * dk), dtype=complex)) for i in range(N): - for j in range(i+1,N): - Fij=np.zeros((2*T,2*T)) - for j1 in range(2*T): - for j2 in range(2*T): - Fij[j1,j2]=fij(alpha_grid[j1],gamma_grid[j2],i,j,loss) - for k in range(1,Lmax+1): - dk=2*k+1; C[k-1][j*dk:(j+1)*dk,i*dk:(i+1)*dk]=fijhat_k(k,Fij)#*dk - for k in range(1,Lmax+1): - C[k-1]=C[k-1]+C[k-1].conj().T + for j in range(i + 1, N): + Fij = np.zeros((2 * T, 2 * T)) + for j1 in range(2 * T): + for j2 in range(2 * T): + Fij[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, j, loss) + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + C[k - 1][j * dk : (j + 1) * dk, i * dk : (i + 1) * dk] = fijhat_k( + k, Fij + ) # *dk + for k in range(1, Lmax + 1): + C[k - 1] = C[k - 1] + C[k - 1].conj().T for i in range(N): - Fii=np.zeros((2*T,2*T)) - for j1 in range(2*T): - for j2 in range(2*T): - Fii[j1,j2]=fij(alpha_grid[j1],gamma_grid[j2],i,i,loss) - for k in range(1,Lmax+1): - dk=2*k+1 - C[k-1][i*dk:(i+1)*dk,i*dk:(i+1)*dk]=fijhat_k(k,Fii)#*dk - - - for k in range(1,Lmax+1): - [T,Tinv] = self.complex2real(k) - C[k-1] = np.real(np.kron(np.eye(N),Tinv) @ C[k-1]@np.kron(np.eye(N),T)) - C[k-1]=np.round(C[k-1],10) - - return C + Fii = np.zeros((2 * T, 2 * T)) + for j1 in range(2 * T): + for j2 in range(2 * T): + Fii[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, i, loss) + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + C[k - 1][i * dk : (i + 1) * dk, i * dk : (i + 1) * dk] = fijhat_k( + k, Fii + ) # *dk + + for k in range(1, Lmax + 1): + [T, Tinv] = self.complex2real(k) + C[k - 1] = np.real( + np.kron(np.eye(N), Tinv) @ C[k - 1] @ np.kron(np.eye(N), T) + ) + C[k - 1] = np.round(C[k - 1], 10) + + return C @staticmethod def fast_radon_transform(array, angles, use_ramp=False): @@ -215,7 +227,7 @@ def fast_radon_transform(array, angles, use_ramp=False): ) pts = pts.astype(array.dtype) - #array = array.astype(np.float32) + # array = array.astype(np.float32) lines_f = nufft(array, pts).reshape((img_size, -1)) if img_size % 2 == 0: @@ -225,69 +237,87 @@ def fast_radon_transform(array, angles, use_ramp=False): freqs = np.abs(np.pi * y_idx) lines_f *= freqs[:, np.newaxis] - projections = np.real(xp.asnumpy(fft.centered_ifft(xp.asarray(lines_f), axis=0))) + projections = np.real( + xp.asnumpy(fft.centered_ifft(xp.asarray(lines_f), axis=0)) + ) return projections, lines_f - @staticmethod - def Wd(J,beta): + def Wd(J, beta): # compute Wigner small d matrix - d=np.zeros((2*J+1,2*J+1)); - - for m in range(-J,J+1): - for n in range(-J,J+1): - smin=max(0,m-n); smax=min(J+m,J-n); - for s in range(smin,smax+1): - mul=np.sqrt(factorial(J+m))/factorial(J+m-s)*np.sqrt(factorial(J+n))/factorial(s)*np.sqrt(factorial(J-m))/factorial(n-m+s)\ - *np.sqrt(factorial(J-n))/factorial(J-n-s) - d[n+J,m+J]+=mul*(-1)**(n-m+s)*(np.cos(beta/2))**(2*J+m-n-2*s)*(np.sin(beta/2))**(n-m+2*s); - - return d + d = np.zeros((2 * J + 1, 2 * J + 1)) + + for m in range(-J, J + 1): + for n in range(-J, J + 1): + smin = max(0, m - n) + smax = min(J + m, J - n) + for s in range(smin, smax + 1): + mul = ( + np.sqrt(factorial(J + m)) + / factorial(J + m - s) + * np.sqrt(factorial(J + n)) + / factorial(s) + * np.sqrt(factorial(J - m)) + / factorial(n - m + s) + * np.sqrt(factorial(J - n)) + / factorial(J - n - s) + ) + d[n + J, m + J] += ( + mul + * (-1) ** (n - m + s) + * (np.cos(beta / 2)) ** (2 * J + m - n - 2 * s) + * (np.sin(beta / 2)) ** (n - m + 2 * s) + ) + + return d @staticmethod def complex2real(ell): # compute transformation matrices that convert complex representations to real ones - diml=2*ell+1 - Tinv=np.zeros((diml,diml),dtype=complex) + diml = 2 * ell + 1 + Tinv = np.zeros((diml, diml), dtype=complex) for i in range(diml): - if iell: - Tinv[i,i]=(-1)**(i-ell)/np.sqrt(2); Tinv[i,diml-1-i]=1/np.sqrt(2) - - T=np.zeros((diml,diml),dtype=complex) + if i < ell: + Tinv[i, i] = 1j / np.sqrt(2) + Tinv[i, diml - 1 - i] = -1j * (-1) ** (i - ell) / np.sqrt(2) + if i == ell: + Tinv[i, i] = 1 + if i > ell: + Tinv[i, i] = (-1) ** (i - ell) / np.sqrt(2) + Tinv[i, diml - 1 - i] = 1 / np.sqrt(2) + + T = np.zeros((diml, diml), dtype=complex) for i in range(diml): - if iell: - T[i,i]=(-1)**(i-ell)/np.sqrt(2); T[i,diml-1-i]=1j*(-1)**(i-ell)/np.sqrt(2) - return T,Tinv - + if i < ell: + T[i, i] = -1j / np.sqrt(2) + T[i, diml - 1 - i] = 1 / np.sqrt(2) + if i == ell: + T[i, i] = 1 + if i > ell: + T[i, i] = (-1) ** (i - ell) / np.sqrt(2) + T[i, diml - 1 - i] = 1j * (-1) ** (i - ell) / np.sqrt(2) + return T, Tinv ############# # ADMM Step # ############# def admm_sym_J( - self, - sym, - C, - Lmax, - N, - Ngrid, - max_iter, - rho, - ratio, - factor, - mult=1, - Nstep_yI=20, - verbose=True, - GPU=True, + self, + sym, + C, + Lmax, + N, + Ngrid, + max_iter, + rho, + ratio, + factor, + mult=1, + Nstep_yI=20, + verbose=True, + GPU=True, ): # admm for symmetric case if GPU: @@ -295,236 +325,463 @@ def admm_sym_J( else: import numpy as cp - C0,C1,normC,AEq,bEq,AEqAEqtinv,AI_mat_diag,AI_mat_offdiag,bI,Lambda,d0,d1,D0,D1,idx_diag,idx_offdiag,IDX_upper,IDX_lower,X0,X1,Xq,S0,S1,Sq=self.ADMM_preprocessing(C,Lmax,N,Ngrid,GPU) + ( + C0, + C1, + normC, + AEq, + bEq, + AEqAEqtinv, + AI_mat_diag, + AI_mat_offdiag, + bI, + Lambda, + d0, + d1, + D0, + D1, + idx_diag, + idx_offdiag, + IDX_upper, + IDX_lower, + X0, + X1, + Xq, + S0, + S1, + Sq, + ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid, GPU) # S=int(sym[1]); sym_euler=Symmetry_Euler(sym) - rank_Ak,_=self.compute_rank(sym,Lmax); print(rank_Ak) + rank_Ak, _ = self.compute_rank(sym, Lmax) + print(rank_Ak) # rank_Ak=cp.zeros(Lmax) # for ell in range(Lmax): rank_Ak[ell]=np.linalg.matrix_rank(Ak(ell+1,sym_euler)) - AE=[]; AEAETinv=[] - for k in range(1,Lmax+1): - s0=k**2; s1=(k+1)**2 - AEk=cp.zeros((1+s0+s1,2*(s0+s1))) - AEk[0,:s0]=cp.eye(k).T.reshape(-1); AEk[0,s0:s0+s1]=cp.eye(k+1).T.reshape(-1) - for count in range(1,1+s0): - AEk[count,count-1]=1; AEk[count,count-1+s0+s1]=1 - for count in range(1+s0,1+s0+s1): - AEk[count,count-1]=1; AEk[count,count-1+s1+s0]=1 + AE = [] + AEAETinv = [] + for k in range(1, Lmax + 1): + s0 = k**2 + s1 = (k + 1) ** 2 + AEk = cp.zeros((1 + s0 + s1, 2 * (s0 + s1))) + AEk[0, :s0] = cp.eye(k).T.reshape(-1) + AEk[0, s0 : s0 + s1] = cp.eye(k + 1).T.reshape(-1) + for count in range(1, 1 + s0): + AEk[count, count - 1] = 1 + AEk[count, count - 1 + s0 + s1] = 1 + for count in range(1 + s0, 1 + s0 + s1): + AEk[count, count - 1] = 1 + AEk[count, count - 1 + s1 + s0] = 1 AE.append(AEk) - AEAETinv.append(np.linalg.pinv(AEk@AEk.T)) - bE=cp.zeros((Lmax+D0+D1)) + AEAETinv.append(np.linalg.pinv(AEk @ AEk.T)) + bE = cp.zeros((Lmax + D0 + D1)) for k in range(Lmax): - bE[k+d0[k]+d1[k]:]=rank_Ak[k] - bE[k+1+d0[k]+d1[k]:k+1+d0[k+1]+d1[k]]=cp.eye(k+1).T.reshape(-1) - bE[k+1+d0[k+1]+d1[k]:k+1+d0[k+1]+d1[k+1]]=cp.eye(k+2).T.reshape(-1) - bE=np.repeat(bE[:,np.newaxis],N,axis=1) - P=[]; - for k in range(1,Lmax+1): - dk=2*k+1; Pk=cp.eye(dk) + bE[k + d0[k] + d1[k] :] = rank_Ak[k] + bE[k + 1 + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k]] = cp.eye( + k + 1 + ).T.reshape(-1) + bE[k + 1 + d0[k + 1] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = cp.eye( + k + 2 + ).T.reshape(-1) + bE = np.repeat(bE[:, np.newaxis], N, axis=1) + P = [] + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + Pk = cp.eye(dk) for m in range(k): - for l in range(k-m): - Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] - P.append(Pk) - - - def fun_AE(X0,X1,Xd0,Xd1,Xq): - z=cp.zeros((Lmax+D0+D1,N)) + for l in range(k - m): + Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ + (m + 2 * l + 1, m + 2 * l), : + ] + P.append(Pk) + + def fun_AE(X0, X1, Xd0, Xd1, Xq): + z = cp.zeros((Lmax + D0 + D1, N)) for k in range(Lmax): - z[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]=AE[k]@ cp.concatenate((X0[d0[k]:d0[k+1],idx_diag],X1[d1[k]:d1[k+1],idx_diag],Xd0[d0[k]:d0[k+1]],Xd1[d1[k]:d1[k+1]]),axis=0) - return z, AEq@cp.concatenate((Xq,X0[:1,idx_offdiag],X1[:4,idx_offdiag]),axis=0) - - def fun_AET(yE,yEq): - Z0=cp.zeros((D0,N*(N+1)//2)); Z1=cp.zeros((D1,N*(N+1)//2)); - Zd0=cp.zeros((D0,N)); Zd1=cp.zeros((D1,N)) + z[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AE[ + k + ] @ cp.concatenate( + ( + X0[d0[k] : d0[k + 1], idx_diag], + X1[d1[k] : d1[k + 1], idx_diag], + Xd0[d0[k] : d0[k + 1]], + Xd1[d1[k] : d1[k + 1]], + ), + axis=0, + ) + return z, AEq @ cp.concatenate( + (Xq, X0[:1, idx_offdiag], X1[:4, idx_offdiag]), axis=0 + ) + + def fun_AET(yE, yEq): + Z0 = cp.zeros((D0, N * (N + 1) // 2)) + Z1 = cp.zeros((D1, N * (N + 1) // 2)) + Zd0 = cp.zeros((D0, N)) + Zd1 = cp.zeros((D1, N)) for k in range(Lmax): - s0=(k+1)**2; s1=(k+2)**2 - Ztmp=AE[k].T@yE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]] - Z0[d0[k]:d0[k+1],idx_diag]=Ztmp[:s0] - Z1[d1[k]:d1[k+1],idx_diag]=Ztmp[s0:s0+s1] - Zd0[d0[k]:d0[k+1]]=Ztmp[s0+s1:2*s0+s1] - Zd1[d1[k]:d1[k+1]]=Ztmp[2*s0+s1:2*s0+2*s1] - Zq=AEq.T@yEq - Z0[:1,idx_offdiag]=Zq[16:17]; Z1[:4,idx_offdiag]=Zq[17:] - return Z0,Z1,Zd0,Zd1,Zq[:16] - - def fun_AI(X0,X1): - z=cp.zeros((Ngrid,N*(N+1)//2)); tmp=cp.concatenate((X0,X1),axis=0); - z[:,idx_diag]=AI_mat_diag@tmp[:,idx_diag]; z[:,idx_offdiag]=AI_mat_offdiag@tmp[:,idx_offdiag] + s0 = (k + 1) ** 2 + s1 = (k + 2) ** 2 + Ztmp = AE[k].T @ yE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] + Z0[d0[k] : d0[k + 1], idx_diag] = Ztmp[:s0] + Z1[d1[k] : d1[k + 1], idx_diag] = Ztmp[s0 : s0 + s1] + Zd0[d0[k] : d0[k + 1]] = Ztmp[s0 + s1 : 2 * s0 + s1] + Zd1[d1[k] : d1[k + 1]] = Ztmp[2 * s0 + s1 : 2 * s0 + 2 * s1] + Zq = AEq.T @ yEq + Z0[:1, idx_offdiag] = Zq[16:17] + Z1[:4, idx_offdiag] = Zq[17:] + return Z0, Z1, Zd0, Zd1, Zq[:16] + + def fun_AI(X0, X1): + z = cp.zeros((Ngrid, N * (N + 1) // 2)) + tmp = cp.concatenate((X0, X1), axis=0) + z[:, idx_diag] = AI_mat_diag @ tmp[:, idx_diag] + z[:, idx_offdiag] = AI_mat_offdiag @ tmp[:, idx_offdiag] # return AI_mat@cp.concatenate((X0,X1),axis=0) return z def fun_AIT(yI): # Z=AI_mat.T@yI - # return Z[:D0,:], Z[D0:,:] - Z=cp.zeros((D0+D1,N*(N+1)//2)); Z[:,idx_diag]=AI_mat_diag.T@yI[:,idx_diag]; Z[:,idx_offdiag]=AI_mat_offdiag.T@yI[:,idx_offdiag]; - return Z[:D0,:], Z[D0:,:] - - def update_S(C0,C1,yE,yEq,yI,X0,X1,Xd0,Xd1,Xq,rho,Lmax,N): - Z0,Z1,Zd0,Zd1,Zq=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) - S0=C0-Z0-AIT_yI0-X0/rho; S1=C1-Z1-AIT_yI1-X1/rho; - Sd0=-Zd0-Xd0/rho; Sd1=-Zd1-Xd1/rho; Sq=-Zq-Xq/rho - tic0=time.perf_counter(); - for k in range(1,Lmax+1): - tmp=self.mat_block(S0[d0[k-1]:d0[k],:],N,k,IDX_upper,IDX_lower,idx_offdiag) - tmp=self.psd_projection(tmp) - S0[d0[k-1]:d0[k],:]=self.vec_block(tmp,N,k,IDX_upper) - - tmp=self.mat_block(S1[d1[k-1]:d1[k],:],N,k+1,IDX_upper,IDX_lower,idx_offdiag) - tmp=self.psd_projection(tmp) - S1[d1[k-1]:d1[k],:]=self.vec_block(tmp,N,k+1,IDX_upper) - toc0=time.perf_counter(); Time[0]+=toc0-tic0 - tic1=time.perf_counter(); - for k in range(1,Lmax+1): + # return Z[:D0,:], Z[D0:,:] + Z = cp.zeros((D0 + D1, N * (N + 1) // 2)) + Z[:, idx_diag] = AI_mat_diag.T @ yI[:, idx_diag] + Z[:, idx_offdiag] = AI_mat_offdiag.T @ yI[:, idx_offdiag] + return Z[:D0, :], Z[D0:, :] + + def update_S(C0, C1, yE, yEq, yI, X0, X1, Xd0, Xd1, Xq, rho, Lmax, N): + Z0, Z1, Zd0, Zd1, Zq = fun_AET(yE, yEq) + AIT_yI0, AIT_yI1 = fun_AIT(yI) + S0 = C0 - Z0 - AIT_yI0 - X0 / rho + S1 = C1 - Z1 - AIT_yI1 - X1 / rho + Sd0 = -Zd0 - Xd0 / rho + Sd1 = -Zd1 - Xd1 / rho + Sq = -Zq - Xq / rho + tic0 = time.perf_counter() + for k in range(1, Lmax + 1): + tmp = self.mat_block( + S0[d0[k - 1] : d0[k], :], N, k, IDX_upper, IDX_lower, idx_offdiag + ) + tmp = self.psd_projection(tmp) + S0[d0[k - 1] : d0[k], :] = self.vec_block(tmp, N, k, IDX_upper) + + tmp = self.mat_block( + S1[d1[k - 1] : d1[k], :], + N, + k + 1, + IDX_upper, + IDX_lower, + idx_offdiag, + ) + tmp = self.psd_projection(tmp) + S1[d1[k - 1] : d1[k], :] = self.vec_block(tmp, N, k + 1, IDX_upper) + toc0 = time.perf_counter() + Time[0] += toc0 - tic0 + tic1 = time.perf_counter() + for k in range(1, Lmax + 1): for n in range(N): - tmp=self.transform_back_block(Sd0[d0[k-1]:d0[k],n],Sd1[d1[k-1]:d1[k],n],k,P[k-1]) - tmp=self.psd_projection(tmp) - Sd0[d0[k-1]:d0[k],n],Sd1[d1[k-1]:d1[k],n]=self.transform_block(tmp,k,P[k-1]) - toc1=time.perf_counter(); Time[1]+=toc1-tic1 - tic2=time.perf_counter(); - for count in range(N*(N-1)//2): - tmp=self.psd_projection(Sq[:,count].reshape(4,4).T) - Sq[:,count]=tmp.T.reshape(16) - toc2=time.perf_counter(); Time[2]+=toc2-tic2 - return S0,S1,Sd0,Sd1,Sq - - - def update_yE(C0,C1,X0,X1,Xd0,Xd1,Xq,S0,S1,Sd0,Sd1,Sq,yI,rho): - AIT_yI0,AIT_yI1=fun_AIT(yI) - z,zq=fun_AE(-X0/rho+C0-S0-AIT_yI0,-X1/rho+C1-S1-AIT_yI1,-Xd0/rho-Sd0,-Xd1/rho-Sd1,-Xq/rho-Sq) - yE=cp.zeros((Lmax+D0+D1,N)) + tmp = self.transform_back_block( + Sd0[d0[k - 1] : d0[k], n], + Sd1[d1[k - 1] : d1[k], n], + k, + P[k - 1], + ) + tmp = self.psd_projection(tmp) + Sd0[d0[k - 1] : d0[k], n], Sd1[d1[k - 1] : d1[k], n] = ( + self.transform_block(tmp, k, P[k - 1]) + ) + toc1 = time.perf_counter() + Time[1] += toc1 - tic1 + tic2 = time.perf_counter() + for count in range(N * (N - 1) // 2): + tmp = self.psd_projection(Sq[:, count].reshape(4, 4).T) + Sq[:, count] = tmp.T.reshape(16) + toc2 = time.perf_counter() + Time[2] += toc2 - tic2 + return S0, S1, Sd0, Sd1, Sq + + def update_yE(C0, C1, X0, X1, Xd0, Xd1, Xq, S0, S1, Sd0, Sd1, Sq, yI, rho): + AIT_yI0, AIT_yI1 = fun_AIT(yI) + z, zq = fun_AE( + -X0 / rho + C0 - S0 - AIT_yI0, + -X1 / rho + C1 - S1 - AIT_yI1, + -Xd0 / rho - Sd0, + -Xd1 / rho - Sd1, + -Xq / rho - Sq, + ) + yE = cp.zeros((Lmax + D0 + D1, N)) for k in range(Lmax): - yE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]=AEAETinv[k]@(bE[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]/rho+z[k+d0[k]+d1[k]:k+1+d0[k+1]+d1[k+1]]) - yEq=AEqAEqtinv@(bEq/rho+zq) - return yE,yEq - - - def update_yI(C0,C1,X0,X1,S0,S1,yE,yEq,yI,rho,Lambda): - Z0,Z1,_,_,_=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) - tmp=fun_AI(-X0/rho+C0-S0-Z0-AIT_yI0,-X1/rho+C1-S1-Z1-AIT_yI1) - yI=yI+bI/rho/Lambda+tmp/Lambda - yI=np.maximum(yI,0) + yE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AEAETinv[k] @ ( + bE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] / rho + + z[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] + ) + yEq = AEqAEqtinv @ (bEq / rho + zq) + return yE, yEq + + def update_yI(C0, C1, X0, X1, S0, S1, yE, yEq, yI, rho, Lambda): + Z0, Z1, _, _, _ = fun_AET(yE, yEq) + AIT_yI0, AIT_yI1 = fun_AIT(yI) + tmp = fun_AI( + -X0 / rho + C0 - S0 - Z0 - AIT_yI0, -X1 / rho + C1 - S1 - Z1 - AIT_yI1 + ) + yI = yI + bI / rho / Lambda + tmp / Lambda + yI = np.maximum(yI, 0) return yI - def update_X(C0,C1,X0,X1,Xd0,Xd1,Xq,yE,yEq,yI,S0,S1,Sd0,Sd1,Sq,rho): - Z0,Z1,Zd0,Zd1,Zq=fun_AET(yE,yEq); AIT_yI0,AIT_yI1=fun_AIT(yI) - tmp=S0+Z0+AIT_yI0-C0; X0=X0+mult*rho*tmp; resX0=np.linalg.norm(tmp) - tmp=S1+Z1+AIT_yI1-C1; X1=X1+mult*rho*tmp; resX1=np.linalg.norm(tmp) - tmp=Sd0+Zd0; Xd0=Xd0+mult*rho*tmp; resXd0=np.linalg.norm(tmp) - tmp=Sd1+Zd1; Xd1=Xd1+mult*rho*tmp; resXd1=np.linalg.norm(tmp) - tmp=Sq+Zq; Xq=Xq+mult*rho*tmp; resXq=np.linalg.norm(tmp) - return X0,X1,Xd0,Xd1,Xq,np.sqrt(resX0**2+resX1**2+resXd0**2+resXd1**2+resXq**2) - - def update_rho(X0,X1,Xd0,Xd1,Xq,bE,bEq,bI,res_X,rho,factor,normC): - z,zq=fun_AE(X0,X1,Xd0,Xd1,Xq) - res_eq=np.linalg.norm(z-bE)/(1+np.linalg.norm(bE))+np.linalg.norm(zq-bEq)/(1+np.linalg.norm(bEq)); - res_inq=np.linalg.norm(np.maximum(bI-fun_AI(X0,X1),0))/(1+abs(bI)*np.sqrt(Ngrid*N**2)); - p_resnorm=res_eq+res_inq - d_resnorm=res_X/(1+normC) - if d_resnorm>ratio*p_resnorm: rho=rho*factor - if d_resnorm ratio * p_resnorm: + rho = rho * factor + if d_resnorm < ratio * p_resnorm: + rho = rho / factor + return rho, p_resnorm, d_resnorm def print_updates(verbose=True): # X_admm=transform_coeff_back(X0,X1,Lmax,N); obj_p=0 # for k in range(Lmax): obj_p+=cp.trace(C[k]@X_admm[k]) - if verbose: - obj_p=cp.vdot(C0[:,idx_diag],X0[:,idx_diag])+cp.vdot(C1[:,idx_diag],X1[:,idx_diag])+2*cp.vdot(C0[:,idx_offdiag],X0[:,idx_offdiag])+2*cp.vdot(C1[:,idx_offdiag],X1[:,idx_offdiag]) - obj_d=cp.vdot(yE,bE)+2*cp.vdot(yEq,bEq)+cp.vdot(yI[:,idx_diag],bI*cp.ones((Ngrid,N)))+2*cp.vdot(yI[:,idx_offdiag],bI*cp.ones((Ngrid,N*(N-1)//2))); - - - z,zq=fun_AE(X0,X1,Xd0,Xd1,Xq); - res_eq=np.linalg.norm(z-bE)/(1+np.linalg.norm(bE))+np.linalg.norm(zq-bEq)/(1+np.linalg.norm(bEq)); - res_inq=np.linalg.norm(np.maximum(bI-fun_AI(X0,X1),0))/(1+abs(bI)*np.sqrt(Ngrid*N*(N+1)/2)); - res_psdX=0 - for k in range(1,Lmax+1): - tmp=self.mat_block(X0[d0[k-1]:d0[k],:],N,k,IDX_upper,IDX_lower,idx_offdiag); res_psdX+=np.linalg.norm(self.psd_projection(-tmp)); #res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) - tmp=self.mat_block(X1[d1[k-1]:d1[k],:],N,k+1,IDX_upper,IDX_lower,idx_offdiag); res_psdX+=np.linalg.norm(self.psd_projection(-tmp)); #res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) - res_psdX=res_psdX/(1+np.linalg.norm(X0)+np.linalg.norm(X1)) - res_psdD=0 - for k in range(1,Lmax+1): + if verbose: + obj_p = ( + cp.vdot(C0[:, idx_diag], X0[:, idx_diag]) + + cp.vdot(C1[:, idx_diag], X1[:, idx_diag]) + + 2 * cp.vdot(C0[:, idx_offdiag], X0[:, idx_offdiag]) + + 2 * cp.vdot(C1[:, idx_offdiag], X1[:, idx_offdiag]) + ) + obj_d = ( + cp.vdot(yE, bE) + + 2 * cp.vdot(yEq, bEq) + + cp.vdot(yI[:, idx_diag], bI * cp.ones((Ngrid, N))) + + 2 + * cp.vdot( + yI[:, idx_offdiag], bI * cp.ones((Ngrid, N * (N - 1) // 2)) + ) + ) + + z, zq = fun_AE(X0, X1, Xd0, Xd1, Xq) + res_eq = np.linalg.norm(z - bE) / ( + 1 + np.linalg.norm(bE) + ) + np.linalg.norm(zq - bEq) / (1 + np.linalg.norm(bEq)) + res_inq = np.linalg.norm(np.maximum(bI - fun_AI(X0, X1), 0)) / ( + 1 + abs(bI) * np.sqrt(Ngrid * N * (N + 1) / 2) + ) + res_psdX = 0 + for k in range(1, Lmax + 1): + tmp = self.mat_block( + X0[d0[k - 1] : d0[k], :], + N, + k, + IDX_upper, + IDX_lower, + idx_offdiag, + ) + res_psdX += np.linalg.norm(self.psd_projection(-tmp)) + # res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) + tmp = self.mat_block( + X1[d1[k - 1] : d1[k], :], + N, + k + 1, + IDX_upper, + IDX_lower, + idx_offdiag, + ) + res_psdX += np.linalg.norm(self.psd_projection(-tmp)) + # res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) + res_psdX = res_psdX / (1 + np.linalg.norm(X0) + np.linalg.norm(X1)) + res_psdD = 0 + for k in range(1, Lmax + 1): for n in range(N): - tmp=self.transform_back_block(Xd0[d0[k-1]:d0[k],n],Xd1[d1[k-1]:d1[k],n],k,P[k-1]) - res_psdD+=np.linalg.norm(self.psd_projection(-tmp)); #res_psdD+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) - res_psdD=res_psdD/(1+np.linalg.norm(Xd0)+np.linalg.norm(Xd1)) - res_psdQ=0 - for count in range(N*(N-1)//2): - tmp=Xq[:,count].reshape(4,4).T; res_psdQ+=np.linalg.norm(self.psd_projection(-tmp)); #res_psdQ+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) - res_psdQ=res_psdQ/(1+np.linalg.norm(Xq)) - - - normS=np.sqrt(np.linalg.norm(S0)**2+np.linalg.norm(S1)**2+np.linalg.norm(Sd0)**2+np.linalg.norm(Sd1)**2+np.linalg.norm(Sq)**2) - normX=np.sqrt(np.linalg.norm(X0)**2+np.linalg.norm(X1)**2+np.linalg.norm(Xd0)**2+np.linalg.norm(Xd1)**2+np.linalg.norm(Xq)**2) - p_res=res_eq+res_inq+res_psdX+res_psdD+res_psdQ - d_res=res_X/(1+normC) - print('Iter %i'%t+': p_res=%1.5f'%p_res+', d_res=%1.5f'%d_res+', obj_primal=%1.2f'%obj_p+', obj_dual=%1.2f'%obj_d+', duality gap=%1.2f'%(obj_p-obj_d)+ - '\n eq_res=%1.5f'%res_eq+', inq_res=%1.5f'%res_inq+', psd_res=%1.5f'%(res_psdX+res_psdD+res_psdQ)+', |S|=%1.2f'%normS+', |X|=%1.2f'%normX) - - - Xd0=cp.zeros((D0,N)); Xd1=cp.zeros((D1,N)); Sd0=cp.zeros(Xd0.shape); Sd1=cp.zeros(Xd1.shape); - yI=cp.zeros((Ngrid,N*(N+1)//2)); yE=cp.zeros(bE.shape); yEq=cp.zeros(bEq.shape) - - IDX=np.arange(3); Time=np.zeros(4) + tmp = self.transform_back_block( + Xd0[d0[k - 1] : d0[k], n], + Xd1[d1[k - 1] : d1[k], n], + k, + P[k - 1], + ) + res_psdD += np.linalg.norm(self.psd_projection(-tmp)) + # res_psdD+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) + res_psdD = res_psdD / (1 + np.linalg.norm(Xd0) + np.linalg.norm(Xd1)) + res_psdQ = 0 + for count in range(N * (N - 1) // 2): + tmp = Xq[:, count].reshape(4, 4).T + res_psdQ += np.linalg.norm(self.psd_projection(-tmp)) + # res_psdQ+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) + res_psdQ = res_psdQ / (1 + np.linalg.norm(Xq)) + + normS = np.sqrt( + np.linalg.norm(S0) ** 2 + + np.linalg.norm(S1) ** 2 + + np.linalg.norm(Sd0) ** 2 + + np.linalg.norm(Sd1) ** 2 + + np.linalg.norm(Sq) ** 2 + ) + normX = np.sqrt( + np.linalg.norm(X0) ** 2 + + np.linalg.norm(X1) ** 2 + + np.linalg.norm(Xd0) ** 2 + + np.linalg.norm(Xd1) ** 2 + + np.linalg.norm(Xq) ** 2 + ) + p_res = res_eq + res_inq + res_psdX + res_psdD + res_psdQ + d_res = res_X / (1 + normC) + print( + "Iter %i" % t + + ": p_res=%1.5f" % p_res + + ", d_res=%1.5f" % d_res + + ", obj_primal=%1.2f" % obj_p + + ", obj_dual=%1.2f" % obj_d + + ", duality gap=%1.2f" % (obj_p - obj_d) + + "\n eq_res=%1.5f" % res_eq + + ", inq_res=%1.5f" % res_inq + + ", psd_res=%1.5f" % (res_psdX + res_psdD + res_psdQ) + + ", |S|=%1.2f" % normS + + ", |X|=%1.2f" % normX + ) + + Xd0 = cp.zeros((D0, N)) + Xd1 = cp.zeros((D1, N)) + Sd0 = cp.zeros(Xd0.shape) + Sd1 = cp.zeros(Xd1.shape) + yI = cp.zeros((Ngrid, N * (N + 1) // 2)) + yE = cp.zeros(bE.shape) + yEq = cp.zeros(bEq.shape) + + IDX = np.arange(3) + Time = np.zeros(4) for t in range(max_iter): np.random.shuffle(IDX) for idx in IDX: - if idx==0: S0,S1,Sd0,Sd1,Sq=update_S(C0,C1,yE,yEq,yI,X0,X1,Xd0,Xd1,Xq,rho,Lmax,N); - if idx==1: yE,yEq=update_yE(C0,C1,X0,X1,Xd0,Xd1,Xq,S0,S1,Sd0,Sd1,Sq,yI,rho); - if idx==2: + if idx == 0: + S0, S1, Sd0, Sd1, Sq = update_S( + C0, C1, yE, yEq, yI, X0, X1, Xd0, Xd1, Xq, rho, Lmax, N + ) + if idx == 1: + yE, yEq = update_yE( + C0, C1, X0, X1, Xd0, Xd1, Xq, S0, S1, Sd0, Sd1, Sq, yI, rho + ) + if idx == 2: for kk in range(Nstep_yI): - yI=update_yI(C0,C1,X0,X1,S0,S1,yE,yEq,yI,rho,Lambda) - X0,X1,Xd0,Xd1,Xq,res_X=update_X(C0,C1,X0,X1,Xd0,Xd1,Xq,yE,yEq,yI,S0,S1,Sd0,Sd1,Sq,rho); - if t%100==0: print_updates(verbose) - rho,p_resnorm,d_resnorm=update_rho(X0,X1,Xd0,Xd1,Xq,bE,bEq,bI,res_X,rho,factor,normC) - - X_admm=self.transform_coeff_back(X0,X1,Lmax,N,IDX_upper,IDX_lower,idx_offdiag) + yI = update_yI(C0, C1, X0, X1, S0, S1, yE, yEq, yI, rho, Lambda) + X0, X1, Xd0, Xd1, Xq, res_X = update_X( + C0, C1, X0, X1, Xd0, Xd1, Xq, yE, yEq, yI, S0, S1, Sd0, Sd1, Sq, rho + ) + if t % 100 == 0: + print_updates(verbose) + rho, p_resnorm, d_resnorm = update_rho( + X0, X1, Xd0, Xd1, Xq, bE, bEq, bI, res_X, rho, factor, normC + ) + + X_admm = self.transform_coeff_back( + X0, X1, Lmax, N, IDX_upper, IDX_lower, idx_offdiag + ) if GPU: - for k in range(Lmax): X_admm[k]=X_admm[k].get() - return X_admm - + for k in range(Lmax): + X_admm[k] = X_admm[k].get() + return X_admm - def ADMM_preprocessing(self, C,Lmax,N,Ngrid,GPU=True): + def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): # compute necessary quantities for ADMM - if GPU: import cupy as cp - else: import numpy as cp + if GPU: + import cupy as cp + else: + import numpy as cp - # compute some useful index sets - count=0; idx_diag=[]; idx_offdiag=[]; + # compute some useful index sets + count = 0 + idx_diag = [] + idx_offdiag = [] for i in range(N): - for j in range(i,N): - if j==i: idx_diag.append(count) - else: idx_offdiag.append(count) - count+=1 - IDX_upper=[]; IDX_lower=[]; + for j in range(i, N): + if j == i: + idx_diag.append(count) + else: + idx_offdiag.append(count) + count += 1 + IDX_upper = [] + IDX_lower = [] for i in range(N): for j in range(N): - if j>=i: IDX_upper.append(i*N+j); + if j >= i: + IDX_upper.append(i * N + j) for j in range(N): for i in range(N): - if j500: - Lambda=cp.linalg.norm(z); z=z/cp.linalg.norm(z); z=AI@(AI.T@z) - Lambda+=2000; print('Largest eigenvalue of AIAIT is approximately %1.2f'%Lambda) + z = cp.random.normal(0, 1, (Ngrid, N**2)) + Lambda = 0 + + while abs(Lambda - cp.linalg.norm(z)) > 500: + Lambda = cp.linalg.norm(z) + z = z / cp.linalg.norm(z) + z = AI @ (AI.T @ z) + Lambda += 2000 + print("Largest eigenvalue of AIAIT is approximately %1.2f" % Lambda) # Lambda=cp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) - return Lambda - - def compute_rank(self, sym,Lmax): - sym_euler,S=self.Symmetry_Euler(sym); rk=cp.zeros(Lmax); A=[] - for k in range(1,Lmax+1): - dk=2*k+1; Ak=np.zeros((dk,dk),dtype=np.complex128) - for s in range(S): Ak+=self.WD(k,sym_euler[s]) - Ak=np.round(Ak/S,6); A.append(Ak); rk[k-1]=np.linalg.matrix_rank(Ak) - return rk,A + return Lambda + + def compute_rank(self, sym, Lmax): + sym_euler, S = self.Symmetry_Euler(sym) + rk = cp.zeros(Lmax) + A = [] + for k in range(1, Lmax + 1): + dk = 2 * k + 1 + Ak = np.zeros((dk, dk), dtype=np.complex128) + for s in range(S): + Ak += self.WD(k, sym_euler[s]) + Ak = np.round(Ak / S, 6) + A.append(Ak) + rk[k - 1] = np.linalg.matrix_rank(Ak) + return rk, A @staticmethod def Symmetry_Euler(sym): - if sym[0]=='C': - order=int(sym[1:]); sym_euler=np.zeros((order,3)) - for i in range(order): sym_euler[i]=spr.from_euler('zyx',[2*np.pi/order*i,0,0]).as_euler('zyz') - - if sym[0]=='D': - order=int(sym[1:]); sym_euler=np.zeros((2*order,3)) - for i in range(order): - sym_euler[i]=spr.from_euler('zyx',[2*np.pi/order*i,0,0]).as_euler('zyz') - sym_euler[i+order]=spr.from_euler('zyx',[2*np.pi/order*i,0,np.pi]).as_euler('zyz') - - if sym=='T12': - sym_euler=np.zeros((12,3)) - sym_euler[1]=spr.from_rotvec(2*np.pi/3*np.array([0,0,1])).as_euler('zyz') - sym_euler[2]=spr.from_rotvec(4*np.pi/3*np.array([0,0,1])).as_euler('zyz') - - - return sym_euler,sym_euler.shape[0] - - def WD(self, J,euler): + if sym[0] == "C": + order = int(sym[1:]) + sym_euler = np.zeros((order, 3)) + for i in range(order): + sym_euler[i] = spr.from_euler( + "zyx", [2 * np.pi / order * i, 0, 0] + ).as_euler("zyz") + + if sym[0] == "D": + order = int(sym[1:]) + sym_euler = np.zeros((2 * order, 3)) + for i in range(order): + sym_euler[i] = spr.from_euler( + "zyx", [2 * np.pi / order * i, 0, 0] + ).as_euler("zyz") + sym_euler[i + order] = spr.from_euler( + "zyx", [2 * np.pi / order * i, 0, np.pi] + ).as_euler("zyz") + + if sym == "T12": + sym_euler = np.zeros((12, 3)) + sym_euler[1] = spr.from_rotvec( + 2 * np.pi / 3 * np.array([0, 0, 1]) + ).as_euler("zyz") + sym_euler[2] = spr.from_rotvec( + 4 * np.pi / 3 * np.array([0, 0, 1]) + ).as_euler("zyz") + + return sym_euler, sym_euler.shape[0] + + def WD(self, J, euler): # compute Wigner D matrix - alpha=euler[0]; beta=euler[1]; gamma=euler[2]; d=self.Wd(J,beta); - D=np.diag(np.exp(-1j*alpha*np.arange(-J,J+1)))@ d @np.diag(np.exp(-1j*gamma*np.arange(-J,J+1))) + alpha = euler[0] + beta = euler[1] + gamma = euler[2] + d = self.Wd(J, beta) + D = ( + np.diag(np.exp(-1j * alpha * np.arange(-J, J + 1))) + @ d + @ np.diag(np.exp(-1j * gamma * np.arange(-J, J + 1))) + ) return D @staticmethod - def Wd(J,beta): + def Wd(J, beta): # compute Wigner small d matrix - d=np.zeros((2*J+1,2*J+1)); - - for m in range(-J,J+1): - for n in range(-J,J+1): - smin=max(0,m-n); smax=min(J+m,J-n); - for s in range(smin,smax+1): - mul=np.sqrt(factorial(J+m))/factorial(J+m-s)*np.sqrt(factorial(J+n))/factorial(s)*np.sqrt(factorial(J-m))/factorial(n-m+s)\ - *np.sqrt(factorial(J-n))/factorial(J-n-s) - d[n+J,m+J]+=mul*(-1)**(n-m+s)*(np.cos(beta/2))**(2*J+m-n-2*s)*(np.sin(beta/2))**(n-m+2*s); - - return d + d = np.zeros((2 * J + 1, 2 * J + 1)) + + for m in range(-J, J + 1): + for n in range(-J, J + 1): + smin = max(0, m - n) + smax = min(J + m, J - n) + for s in range(smin, smax + 1): + mul = ( + np.sqrt(factorial(J + m)) + / factorial(J + m - s) + * np.sqrt(factorial(J + n)) + / factorial(s) + * np.sqrt(factorial(J - m)) + / factorial(n - m + s) + * np.sqrt(factorial(J - n)) + / factorial(J - n - s) + ) + d[n + J, m + J] += ( + mul + * (-1) ** (n - m + s) + * (np.cos(beta / 2)) ** (2 * J + m - n - 2 * s) + * (np.sin(beta / 2)) ** (n - m + 2 * s) + ) + + return d @staticmethod - def mat_block(vecA,N,sz,IDX_upper,IDX_lower,idx_offdiag): - tmp=vecA.T.reshape(N*(N+1)//2,sz,sz).transpose(0,2,1) - AA=cp.zeros((N**2,sz,sz)); AA[IDX_upper]=tmp; AA[IDX_lower]=tmp[idx_offdiag].transpose(0,2,1) - return (AA.reshape(N,N,sz,sz).transpose(0,2,1,3)).reshape(N*sz,N*sz) + def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): + tmp = vecA.T.reshape(N * (N + 1) // 2, sz, sz).transpose(0, 2, 1) + AA = cp.zeros((N**2, sz, sz)) + AA[IDX_upper] = tmp + AA[IDX_lower] = tmp[idx_offdiag].transpose(0, 2, 1) + return (AA.reshape(N, N, sz, sz).transpose(0, 2, 1, 3)).reshape(N * sz, N * sz) @staticmethod def psd_projection(B): # compute the PSD part of a symmstric matrix - evals,evecs=cp.linalg.eigh((B+B.T)/2); evals=cp.maximum(evals,0) - return (evecs*evals) @ evecs.T + evals, evecs = cp.linalg.eigh((B + B.T) / 2) + evals = cp.maximum(evals, 0) + return (evecs * evals) @ evecs.T @staticmethod - def transform_block(A,k,Pk=None): + def transform_block(A, k, Pk=None): if Pk is None: - dk=2*k+1; Pk=cp.eye(dk) + dk = 2 * k + 1 + Pk = cp.eye(dk) for m in range(k): - for l in range(k-m): - Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] - AT=Pk@A@Pk.T - return AT[:k,:k].T.reshape(-1), AT[k:,k:].T.reshape(-1) - + for l in range(k - m): + Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ + (m + 2 * l + 1, m + 2 * l), : + ] + AT = Pk @ A @ Pk.T + return AT[:k, :k].T.reshape(-1), AT[k:, k:].T.reshape(-1) + @staticmethod - def transform_back_block(A0,A1,k,Pk=None): - dk=2*k+1; A=cp.zeros((dk,dk)); A[:k,:k]=A0.reshape(k,k).T; A[k:,k:]=A1.reshape(k+1,k+1).T + def transform_back_block(A0, A1, k, Pk=None): + dk = 2 * k + 1 + A = cp.zeros((dk, dk)) + A[:k, :k] = A0.reshape(k, k).T + A[k:, k:] = A1.reshape(k + 1, k + 1).T if Pk is None: - Pk=cp.eye(dk) + Pk = cp.eye(dk) for m in range(k): - for l in range(k-m): - Pk[(m+2*l,m+2*l+1),:]=Pk[(m+2*l+1,m+2*l),:] - return Pk.T@A@Pk + for l in range(k - m): + Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ + (m + 2 * l + 1, m + 2 * l), : + ] + return Pk.T @ A @ Pk From 45e555b665df11df9496da0b9e656ee18c02c389 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 6 Mar 2026 11:26:41 -0500 Subject: [PATCH 03/38] Use numeric to handle cupy/numpy. Remove duplicate function. --- src/aspire/abinitio/commonline_nug.py | 255 +++++++++++--------------- 1 file changed, 110 insertions(+), 145 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 97a7420d64..c2f9aa249c 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -5,7 +5,7 @@ import numpy as np from scipy.io import loadmat from scipy.spatial.transform import Rotation as spr -from scipy.special import binom, factorial, jacobi, sph_harm +from scipy.special import factorial from aspire.abinitio import CLOrient3D from aspire.nufft import nufft @@ -73,6 +73,7 @@ def __init__( self.sym = symmetry def estimate_rotations(self): + breakpoint() sym_euler, S = self.Symmetry_Euler(self.sym) imgs = self.src.images[:] C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) @@ -142,18 +143,18 @@ def fij(alpha, gamma, i, j, loss): gamma_grid = np.arange(2 * T) * np.pi / T bT = np.zeros(2 * T) - for l in range(2 * T): + for n in range(2 * T): ss = 0 for m in range(T): - ss = ss + np.sin(beta_grid[l] * (2 * m + 1)) / (2 * m + 1) - bT[l] = 2 / T * np.sin(beta_grid[l]) * ss + ss = ss + np.sin(beta_grid[n] * (2 * m + 1)) / (2 * m + 1) + bT[n] = 2 / T * np.sin(beta_grid[n]) * ss BTK = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 btk = np.zeros((dk, dk)) - for l in range(2 * T): - btk = btk + bT[l] * self.Wd(k, beta_grid[l]) + for n in range(2 * T): + btk = btk + bT[n] * self.Wd(k, beta_grid[n]) BTK.append(btk.T) def fijhat_k(k, F): @@ -227,7 +228,7 @@ def fast_radon_transform(array, angles, use_ramp=False): ) pts = pts.astype(array.dtype) - # array = array.astype(np.float32) + #array = array.astype(np.float32) lines_f = nufft(array, pts).reshape((img_size, -1)) if img_size % 2 == 0: @@ -243,35 +244,6 @@ def fast_radon_transform(array, angles, use_ramp=False): return projections, lines_f - @staticmethod - def Wd(J, beta): - # compute Wigner small d matrix - d = np.zeros((2 * J + 1, 2 * J + 1)) - - for m in range(-J, J + 1): - for n in range(-J, J + 1): - smin = max(0, m - n) - smax = min(J + m, J - n) - for s in range(smin, smax + 1): - mul = ( - np.sqrt(factorial(J + m)) - / factorial(J + m - s) - * np.sqrt(factorial(J + n)) - / factorial(s) - * np.sqrt(factorial(J - m)) - / factorial(n - m + s) - * np.sqrt(factorial(J - n)) - / factorial(J - n - s) - ) - d[n + J, m + J] += ( - mul - * (-1) ** (n - m + s) - * (np.cos(beta / 2)) ** (2 * J + m - n - 2 * s) - * (np.sin(beta / 2)) ** (n - m + 2 * s) - ) - - return d - @staticmethod def complex2real(ell): # compute transformation matrices that convert complex representations to real ones @@ -317,14 +289,8 @@ def admm_sym_J( mult=1, Nstep_yI=20, verbose=True, - GPU=True, ): # admm for symmetric case - if GPU: - import cupy as cp - else: - import numpy as cp - ( C0, C1, @@ -350,11 +316,11 @@ def admm_sym_J( S0, S1, Sq, - ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid, GPU) + ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid) # S=int(sym[1]); sym_euler=Symmetry_Euler(sym) rank_Ak, _ = self.compute_rank(sym, Lmax) - print(rank_Ak) + logger.info(f"Rank of Ak: {rank_Ak}") # rank_Ak=cp.zeros(Lmax) # for ell in range(Lmax): rank_Ak[ell]=np.linalg.matrix_rank(Ak(ell+1,sym_euler)) @@ -364,9 +330,9 @@ def admm_sym_J( for k in range(1, Lmax + 1): s0 = k**2 s1 = (k + 1) ** 2 - AEk = cp.zeros((1 + s0 + s1, 2 * (s0 + s1))) - AEk[0, :s0] = cp.eye(k).T.reshape(-1) - AEk[0, s0 : s0 + s1] = cp.eye(k + 1).T.reshape(-1) + AEk = xp.zeros((1 + s0 + s1, 2 * (s0 + s1))) + AEk[0, :s0] = xp.eye(k).T.reshape(-1) + AEk[0, s0 : s0 + s1] = xp.eye(k + 1).T.reshape(-1) for count in range(1, 1 + s0): AEk[count, count - 1] = 1 AEk[count, count - 1 + s0 + s1] = 1 @@ -375,20 +341,20 @@ def admm_sym_J( AEk[count, count - 1 + s1 + s0] = 1 AE.append(AEk) AEAETinv.append(np.linalg.pinv(AEk @ AEk.T)) - bE = cp.zeros((Lmax + D0 + D1)) + bE = xp.zeros((Lmax + D0 + D1)) for k in range(Lmax): bE[k + d0[k] + d1[k] :] = rank_Ak[k] - bE[k + 1 + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k]] = cp.eye( + bE[k + 1 + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k]] = xp.eye( k + 1 ).T.reshape(-1) - bE[k + 1 + d0[k + 1] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = cp.eye( + bE[k + 1 + d0[k + 1] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = xp.eye( k + 2 ).T.reshape(-1) bE = np.repeat(bE[:, np.newaxis], N, axis=1) P = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - Pk = cp.eye(dk) + Pk = xp.eye(dk) for m in range(k): for l in range(k - m): Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ @@ -397,11 +363,11 @@ def admm_sym_J( P.append(Pk) def fun_AE(X0, X1, Xd0, Xd1, Xq): - z = cp.zeros((Lmax + D0 + D1, N)) + z = xp.zeros((Lmax + D0 + D1, N)) for k in range(Lmax): z[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AE[ k - ] @ cp.concatenate( + ] @ xp.concatenate( ( X0[d0[k] : d0[k + 1], idx_diag], X1[d1[k] : d1[k + 1], idx_diag], @@ -410,15 +376,15 @@ def fun_AE(X0, X1, Xd0, Xd1, Xq): ), axis=0, ) - return z, AEq @ cp.concatenate( + return z, AEq @ xp.concatenate( (Xq, X0[:1, idx_offdiag], X1[:4, idx_offdiag]), axis=0 ) def fun_AET(yE, yEq): - Z0 = cp.zeros((D0, N * (N + 1) // 2)) - Z1 = cp.zeros((D1, N * (N + 1) // 2)) - Zd0 = cp.zeros((D0, N)) - Zd1 = cp.zeros((D1, N)) + Z0 = xp.zeros((D0, N * (N + 1) // 2)) + Z1 = xp.zeros((D1, N * (N + 1) // 2)) + Zd0 = xp.zeros((D0, N)) + Zd1 = xp.zeros((D1, N)) for k in range(Lmax): s0 = (k + 1) ** 2 s1 = (k + 2) ** 2 @@ -433,17 +399,17 @@ def fun_AET(yE, yEq): return Z0, Z1, Zd0, Zd1, Zq[:16] def fun_AI(X0, X1): - z = cp.zeros((Ngrid, N * (N + 1) // 2)) - tmp = cp.concatenate((X0, X1), axis=0) + z = xp.zeros((Ngrid, N * (N + 1) // 2)) + tmp = xp.concatenate((X0, X1), axis=0) z[:, idx_diag] = AI_mat_diag @ tmp[:, idx_diag] z[:, idx_offdiag] = AI_mat_offdiag @ tmp[:, idx_offdiag] - # return AI_mat@cp.concatenate((X0,X1),axis=0) + # return AI_mat@xp.concatenate((X0,X1),axis=0) return z def fun_AIT(yI): # Z=AI_mat.T@yI # return Z[:D0,:], Z[D0:,:] - Z = cp.zeros((D0 + D1, N * (N + 1) // 2)) + Z = xp.zeros((D0 + D1, N * (N + 1) // 2)) Z[:, idx_diag] = AI_mat_diag.T @ yI[:, idx_diag] Z[:, idx_offdiag] = AI_mat_offdiag.T @ yI[:, idx_offdiag] return Z[:D0, :], Z[D0:, :] @@ -508,7 +474,7 @@ def update_yE(C0, C1, X0, X1, Xd0, Xd1, Xq, S0, S1, Sd0, Sd1, Sq, yI, rho): -Xd1 / rho - Sd1, -Xq / rho - Sq, ) - yE = cp.zeros((Lmax + D0 + D1, N)) + yE = xp.zeros((Lmax + D0 + D1, N)) for k in range(Lmax): yE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AEAETinv[k] @ ( bE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] / rho @@ -574,21 +540,21 @@ def update_rho(X0, X1, Xd0, Xd1, Xq, bE, bEq, bI, res_X, rho, factor, normC): def print_updates(verbose=True): # X_admm=transform_coeff_back(X0,X1,Lmax,N); obj_p=0 - # for k in range(Lmax): obj_p+=cp.trace(C[k]@X_admm[k]) + # for k in range(Lmax): obj_p+=xp.trace(C[k]@X_admm[k]) if verbose: obj_p = ( - cp.vdot(C0[:, idx_diag], X0[:, idx_diag]) - + cp.vdot(C1[:, idx_diag], X1[:, idx_diag]) - + 2 * cp.vdot(C0[:, idx_offdiag], X0[:, idx_offdiag]) - + 2 * cp.vdot(C1[:, idx_offdiag], X1[:, idx_offdiag]) + xp.vdot(C0[:, idx_diag], X0[:, idx_diag]) + + xp.vdot(C1[:, idx_diag], X1[:, idx_diag]) + + 2 * xp.vdot(C0[:, idx_offdiag], X0[:, idx_offdiag]) + + 2 * xp.vdot(C1[:, idx_offdiag], X1[:, idx_offdiag]) ) obj_d = ( - cp.vdot(yE, bE) - + 2 * cp.vdot(yEq, bEq) - + cp.vdot(yI[:, idx_diag], bI * cp.ones((Ngrid, N))) + xp.vdot(yE, bE) + + 2 * xp.vdot(yEq, bEq) + + xp.vdot(yI[:, idx_diag], bI * xp.ones((Ngrid, N))) + 2 - * cp.vdot( - yI[:, idx_offdiag], bI * cp.ones((Ngrid, N * (N - 1) // 2)) + * xp.vdot( + yI[:, idx_offdiag], bI * xp.ones((Ngrid, N * (N - 1) // 2)) ) ) @@ -657,7 +623,7 @@ def print_updates(verbose=True): ) p_res = res_eq + res_inq + res_psdX + res_psdD + res_psdQ d_res = res_X / (1 + normC) - print( + logger.info( "Iter %i" % t + ": p_res=%1.5f" % p_res + ", d_res=%1.5f" % d_res @@ -671,13 +637,13 @@ def print_updates(verbose=True): + ", |X|=%1.2f" % normX ) - Xd0 = cp.zeros((D0, N)) - Xd1 = cp.zeros((D1, N)) - Sd0 = cp.zeros(Xd0.shape) - Sd1 = cp.zeros(Xd1.shape) - yI = cp.zeros((Ngrid, N * (N + 1) // 2)) - yE = cp.zeros(bE.shape) - yEq = cp.zeros(bEq.shape) + Xd0 = xp.zeros((D0, N)) + Xd1 = xp.zeros((D1, N)) + Sd0 = xp.zeros(Xd0.shape) + Sd1 = xp.zeros(Xd1.shape) + yI = xp.zeros((Ngrid, N * (N + 1) // 2)) + yE = xp.zeros(bE.shape) + yEq = xp.zeros(bEq.shape) IDX = np.arange(3) Time = np.zeros(4) @@ -693,7 +659,7 @@ def print_updates(verbose=True): C0, C1, X0, X1, Xd0, Xd1, Xq, S0, S1, Sd0, Sd1, Sq, yI, rho ) if idx == 2: - for kk in range(Nstep_yI): + for _ in range(Nstep_yI): yI = update_yI(C0, C1, X0, X1, S0, S1, yE, yEq, yI, rho, Lambda) X0, X1, Xd0, Xd1, Xq, res_X = update_X( C0, C1, X0, X1, Xd0, Xd1, Xq, yE, yEq, yI, S0, S1, Sd0, Sd1, Sq, rho @@ -707,18 +673,16 @@ def print_updates(verbose=True): X_admm = self.transform_coeff_back( X0, X1, Lmax, N, IDX_upper, IDX_lower, idx_offdiag ) - if GPU: - for k in range(Lmax): - X_admm[k] = X_admm[k].get() + # if self.GPU: + # for k in range(Lmax): + # X_admm[k] = X_admm[k].get() + for k in range(Lmax): + X_admm[k] = xp.asnumpy(X_admm[k]) return X_admm - def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): + def ADMM_preprocessing(self, C, Lmax, N, Ngrid): # compute necessary quantities for ADMM - if GPU: - import cupy as cp - else: - import numpy as cp - + # compute some useful index sets count = 0 idx_diag = [] @@ -751,7 +715,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): Cnorm = np.sqrt(Cnorm) Xnorm = np.sqrt(Xnorm) for k in range(Lmax): - C[k] = cp.asarray(Xnorm / Cnorm * C[k]) + C[k] = xp.asarray(Xnorm / Cnorm * C[k]) C0, C1 = self.transform_coeff(C, Lmax, N, IDX_upper) normC = np.sqrt(np.linalg.norm(C0) ** 2 + np.linalg.norm(C1) ** 2) del C @@ -766,12 +730,12 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): D1 = d1[-1] # AE and bE for quaternion constraints - AEq = cp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) - AEqAEqtinv = cp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) - bEq = cp.zeros(17) - bEq[:16] = cp.eye(4).reshape(-1) / 4 + AEq = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) + AEqAEqtinv = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) + bEq = xp.zeros(17) + bEq[:16] = xp.eye(4).reshape(-1) / 4 bEq[-1] = 1 - bEq = cp.repeat(bEq[:, cp.newaxis], N * (N - 1) // 2, axis=1) + bEq = xp.repeat(bEq[:, xp.newaxis], N * (N - 1) // 2, axis=1) # AI and bI W0 = [ @@ -790,7 +754,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): # w1[d1[k-1]:d1[k]]=(Lmax-k+2)*(Lmax-k+1)*(k+0.5)*W1[k-1][p].T.reshape(-1) # # this needs double checking # AI_mat[p,:d0[-1]]=w0; AI_mat[p,d0[-1]:]=w1 - # AI_mat=cp.asarray(AI_mat) / 10; + # AI_mat=xp.asarray(AI_mat) / 10; # bI=-(Lmax+2)*(Lmax+1)/2 / 10 AI_mat_offdiag = np.zeros((Ngrid, D0 + D1)) @@ -833,10 +797,10 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): # this needs double checking AI_mat_diag[p, : d0[-1]] = w0 AI_mat_diag[p, d0[-1] :] = w1 - AI_mat_diag = cp.asarray(AI_mat_diag) / 1 - AI_mat_offdiag = cp.asarray(AI_mat_offdiag) / 1 + AI_mat_diag = xp.asarray(AI_mat_diag) / 1 + AI_mat_offdiag = xp.asarray(AI_mat_offdiag) / 1 bI = -(Lmax + 2) * (Lmax + 1) / 2 / 1 - + np.save("AI_mat_offdiag_aspire.npy", AI_mat_offdiag) # largest eigenvalue for AIAIT Lambda = self.largest_eigenvalue(AI_mat_offdiag, Ngrid, N) @@ -844,16 +808,16 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid, GPU=True): II = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - II.append(cp.eye(N * dk)) + II.append(xp.eye(N * dk)) I0, I1 = self.transform_coeff(II, Lmax, N, IDX_upper) - X0 = cp.zeros((D0, N * (N + 1) // 2)) - X1 = cp.zeros((D1, N * (N + 1) // 2)) - Xq = cp.zeros((16, N * (N - 1) // 2)) - # X0,X1=transform_coeff(II,Lmax,N); Xq=cp.zeros((16,N*(N-1))); Xq[0,:]=1; - S0 = cp.copy(I0) - S1 = cp.copy(I1) - Sq = cp.zeros(Xq.shape) - # S0=cp.zeros(X0.shape); S1=cp.zeros(X1.shape); Sq=cp.zeros(Xq.shape) + X0 = xp.zeros((D0, N * (N + 1) // 2)) + X1 = xp.zeros((D1, N * (N + 1) // 2)) + Xq = xp.zeros((16, N * (N - 1) // 2)) + # X0,X1=transform_coeff(II,Lmax,N); Xq=xp.zeros((16,N*(N-1))); Xq[0,:]=1; + S0 = xp.copy(I0) + S1 = xp.copy(I1) + Sq = xp.zeros(Xq.shape) + # S0=xp.zeros(X0.shape); S1=xp.zeros(X1.shape); Sq=xp.zeros(Xq.shape) # return C0,C1,normC,AEq,bEq,AEqAEqtinv,AI_mat,bI,Lambda,d0,d1,D0,D1,idx_diag,idx_offdiag,IDX_upper,IDX_lower,X0,X1,Xq,S0,S1,Sq return ( @@ -1007,8 +971,8 @@ def transform_coeff(self, A, Lmax, N, IDX_upper): for k in range(1, Lmax + 1): d0.append(d0[-1] + k**2) d1.append(d1[-1] + (k + 1) ** 2) - A0 = cp.zeros((d0[-1], N * (N + 1) // 2)) - A1 = cp.zeros((d1[-1], N * (N + 1) // 2)) + A0 = xp.zeros((d0[-1], N * (N + 1) // 2)) + A1 = xp.zeros((d1[-1], N * (N + 1) // 2)) for k in range(1, Lmax + 1): a0, a1 = self.permutek(A[k - 1], k, N) A0[d0[k - 1] : d0[k], :] = self.vec_block(a0, N, k, IDX_upper) @@ -1024,7 +988,7 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia A = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - Ak = cp.zeros((N * dk, N * dk)) + Ak = xp.zeros((N * dk, N * dk)) Ak[: N * k, : N * k] = self.mat_block( A0[d0[k - 1] : d0[k], :], N, k, IDX_upper, IDX_lower, idx_offdiag ) @@ -1037,20 +1001,20 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia @staticmethod def permutek(Ak, k, N): - AkP = cp.copy(Ak) + AkP = xp.copy(Ak) dk = 2 * k + 1 - Pk = cp.eye(dk) + Pk = xp.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[(m + 2 * l + 1, m + 2 * l), :] - AkP = cp.kron(cp.eye(N), Pk) @ Ak @ cp.kron(cp.eye(N), Pk.T) + for n in range(k - m): + Pk[(m + 2 * n, m + 2 * n + 1), :] = Pk[(m + 2 * n + 1, m + 2 * n), :] + AkP = xp.kron(xp.eye(N), Pk) @ Ak @ xp.kron(xp.eye(N), Pk.T) - Pk = cp.eye(N * dk) - idx = cp.concatenate((cp.arange(dk - k, dk), cp.arange(k + 1))) + Pk = xp.eye(N * dk) + idx = xp.concatenate((xp.arange(dk - k, dk), xp.arange(k + 1))) for m in range(N - 1): - for l in range(N - 1 - m): - Pk[k * (m + 1) + l * dk : k * (m + 1) + (l + 1) * dk] = Pk[ - k * (m + 1) + l * dk : k * (m + 1) + (l + 1) * dk + for n in range(N - 1 - m): + Pk[k * (m + 1) + n * dk : k * (m + 1) + (n + 1) * dk] = Pk[ + k * (m + 1) + n * dk : k * (m + 1) + (n + 1) * dk ][idx, :] AkP = Pk @ AkP @ Pk.T return AkP[: N * k, : N * k], AkP[N * k :, N * k :] @@ -1058,20 +1022,20 @@ def permutek(Ak, k, N): @staticmethod def permutek_back(Ak, k, N): dk = 2 * k + 1 - Pk = cp.eye(N * dk) - idx = cp.concatenate((cp.arange(dk - k, dk), cp.arange(k + 1))) + Pk = xp.eye(N * dk) + idx = xp.concatenate((xp.arange(dk - k, dk), xp.arange(k + 1))) for m in range(N - 1): - for l in range(N - 1 - m): - Pk[k * (m + 1) + l * dk : k * (m + 1) + (l + 1) * dk] = Pk[ - k * (m + 1) + l * dk : k * (m + 1) + (l + 1) * dk + for n in range(N - 1 - m): + Pk[k * (m + 1) + n * dk : k * (m + 1) + (n + 1) * dk] = Pk[ + k * (m + 1) + n * dk : k * (m + 1) + (n + 1) * dk ][idx, :] AkB = Pk.T @ Ak @ Pk dk = 2 * k + 1 - Pk = cp.eye(dk) + Pk = xp.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[(m + 2 * l + 1, m + 2 * l), :] - AkB = cp.kron(cp.eye(N), Pk.T) @ AkB @ cp.kron(cp.eye(N), Pk) + for n in range(k - m): + Pk[(m + 2 * n, m + 2 * n + 1), :] = Pk[(m + 2 * n + 1, m + 2 * n), :] + AkB = xp.kron(xp.eye(N), Pk.T) @ AkB @ xp.kron(xp.eye(N), Pk) return AkB @staticmethod @@ -1082,21 +1046,22 @@ def vec_block(A, N, sz, IDX_upper): @staticmethod def largest_eigenvalue(AI, Ngrid, N): # find the largest eigenvalue of the operator AI - z = cp.random.normal(0, 1, (Ngrid, N**2)) + np.random.seed(0) + z = xp.random.normal(0, 1, (Ngrid, N**2)) Lambda = 0 - while abs(Lambda - cp.linalg.norm(z)) > 500: - Lambda = cp.linalg.norm(z) - z = z / cp.linalg.norm(z) + while abs(Lambda - xp.linalg.norm(z)) > 500: + Lambda = xp.linalg.norm(z) + z = z / xp.linalg.norm(z) z = AI @ (AI.T @ z) Lambda += 2000 - print("Largest eigenvalue of AIAIT is approximately %1.2f" % Lambda) - # Lambda=cp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) + logger.info("Largest eigenvalue of AIAIT is approximately %1.2f" % Lambda) + # Lambda=xp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) return Lambda def compute_rank(self, sym, Lmax): sym_euler, S = self.Symmetry_Euler(sym) - rk = cp.zeros(Lmax) + rk = xp.zeros(Lmax) A = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 @@ -1185,7 +1150,7 @@ def Wd(J, beta): @staticmethod def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): tmp = vecA.T.reshape(N * (N + 1) // 2, sz, sz).transpose(0, 2, 1) - AA = cp.zeros((N**2, sz, sz)) + AA = xp.zeros((N**2, sz, sz)) AA[IDX_upper] = tmp AA[IDX_lower] = tmp[idx_offdiag].transpose(0, 2, 1) return (AA.reshape(N, N, sz, sz).transpose(0, 2, 1, 3)).reshape(N * sz, N * sz) @@ -1193,15 +1158,15 @@ def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): @staticmethod def psd_projection(B): # compute the PSD part of a symmstric matrix - evals, evecs = cp.linalg.eigh((B + B.T) / 2) - evals = cp.maximum(evals, 0) + evals, evecs = xp.linalg.eigh((B + B.T) / 2) + evals = xp.maximum(evals, 0) return (evecs * evals) @ evecs.T @staticmethod def transform_block(A, k, Pk=None): if Pk is None: dk = 2 * k + 1 - Pk = cp.eye(dk) + Pk = xp.eye(dk) for m in range(k): for l in range(k - m): Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ @@ -1213,11 +1178,11 @@ def transform_block(A, k, Pk=None): @staticmethod def transform_back_block(A0, A1, k, Pk=None): dk = 2 * k + 1 - A = cp.zeros((dk, dk)) + A = xp.zeros((dk, dk)) A[:k, :k] = A0.reshape(k, k).T A[k:, k:] = A1.reshape(k + 1, k + 1).T if Pk is None: - Pk = cp.eye(dk) + Pk = xp.eye(dk) for m in range(k): for l in range(k - m): Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ From ce139825a955562216c3a035e9c6e48c1ed9ec2d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 9 Mar 2026 10:22:27 -0400 Subject: [PATCH 04/38] tox --- src/aspire/abinitio/commonline_nug.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index c2f9aa249c..410ebb6fc8 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -1,7 +1,6 @@ import logging import time -import cupy as cp import numpy as np from scipy.io import loadmat from scipy.spatial.transform import Rotation as spr @@ -73,7 +72,6 @@ def __init__( self.sym = symmetry def estimate_rotations(self): - breakpoint() sym_euler, S = self.Symmetry_Euler(self.sym) imgs = self.src.images[:] C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) @@ -228,7 +226,7 @@ def fast_radon_transform(array, angles, use_ramp=False): ) pts = pts.astype(array.dtype) - #array = array.astype(np.float32) + # array = array.astype(np.float32) lines_f = nufft(array, pts).reshape((img_size, -1)) if img_size % 2 == 0: @@ -356,9 +354,9 @@ def admm_sym_J( dk = 2 * k + 1 Pk = xp.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ - (m + 2 * l + 1, m + 2 * l), : + for el in range(k - m): + Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ + (m + 2 * el + 1, m + 2 * el), : ] P.append(Pk) @@ -682,7 +680,6 @@ def print_updates(verbose=True): def ADMM_preprocessing(self, C, Lmax, N, Ngrid): # compute necessary quantities for ADMM - # compute some useful index sets count = 0 idx_diag = [] @@ -1168,9 +1165,9 @@ def transform_block(A, k, Pk=None): dk = 2 * k + 1 Pk = xp.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ - (m + 2 * l + 1, m + 2 * l), : + for el in range(k - m): + Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ + (m + 2 * el + 1, m + 2 * el), : ] AT = Pk @ A @ Pk.T return AT[:k, :k].T.reshape(-1), AT[k:, k:].T.reshape(-1) @@ -1184,8 +1181,8 @@ def transform_back_block(A0, A1, k, Pk=None): if Pk is None: Pk = xp.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ - (m + 2 * l + 1, m + 2 * l), : + for el in range(k - m): + Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ + (m + 2 * el + 1, m + 2 * el), : ] return Pk.T @ A @ Pk From c0a3bdfd34560c688299dbf811bca78b4108722c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 12 Mar 2026 14:52:11 -0400 Subject: [PATCH 05/38] build full pf --- src/aspire/abinitio/commonline_nug.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 410ebb6fc8..27d2e134b3 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -9,7 +9,7 @@ from aspire.abinitio import CLOrient3D from aspire.nufft import nufft from aspire.numeric import fft, xp -from aspire.operators import wemd_embed +from aspire.operators import PolarFT, wemd_embed logger = logging.getLogger(__name__) @@ -71,6 +71,12 @@ def __init__( self.Nstep_yI = Nstep_yI self.sym = symmetry + self._build_full_pft() + + def _build_full_pft(self): + pf = self.pf + self.pf_full = PolarFT.half_to_full(pf) + def estimate_rotations(self): sym_euler, S = self.Symmetry_Euler(self.sym) imgs = self.src.images[:] @@ -105,6 +111,8 @@ def compute_coeff(self, Img, loss, Lmax, T): angular_sampling = np.arange(0, 360, 1) line_proj = np.zeros((L, n_theta, N)) Img_pft = np.zeros((L, n_theta, N), dtype=complex) + + # Replace with Image.project() later Img = Img.asnumpy() for n in range(N): line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( @@ -126,6 +134,17 @@ def fij(alpha, gamma, i, j, loss): Si = Ii_hat[:, int(idxi)] Sj = Ij_hat[:, int(idxj)] + + # Using aspire PolarFT. Replace later + # Ii_hat = self.pf_full[i] + # Ij_hat = self.pf_full[j] + # idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + # idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + + # Si = Ii_hat[int(idxi)] + # Sj = Ij_hat[int(idxj)] + # norm_new = np.linalg.norm(Si - Sj, 1) + return np.linalg.norm(Si - Sj, 1) if loss == "wemd": @@ -729,6 +748,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): # AE and bE for quaternion constraints AEq = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) AEqAEqtinv = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) + bEq = xp.zeros(17) bEq[:16] = xp.eye(4).reshape(-1) / 4 bEq[-1] = 1 From b8cbc7a9005dd54e91fbd1943581cc8744134d0a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 7 Apr 2026 09:48:52 -0400 Subject: [PATCH 06/38] compute_fejer_weights method. Still need to optimize (very slow) --- src/aspire/abinitio/commonline_nug.py | 46 ++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 27d2e134b3..41f48f35f2 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -41,7 +41,7 @@ def __init__( **kwargs, ): """ - Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. + Initialize object for estimating 3D orientations for symmetric molecules. :param src: The source object of 2D denoised or class-averaged images with metadata :param symmetry: A string, ie. 'C3', indicating the symmetry type. @@ -763,6 +763,9 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): loadmat("data/Fejer/Ngrid=%i" % Ngrid + "/k=%i" % k + ".mat")["W1"] for k in range(1, Lmax + 1) ] + + W0k, W1k = self.compute_fejer_weights() + # AI_mat=np.zeros((Ngrid,D0+D1)) # for p in range(Ngrid): # w0=np.zeros(D0); w1=np.zeros(D1) @@ -864,6 +867,47 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): Sq, ) + def compute_fejer_weights(self): + SO3_grid = loadmat("data/SO3_grid.mat")["SO3"] + Ngrid = SO3_grid.shape[0] + start = 1 + + TT = [] + TTI = [] + for ell in range(start, self.Lmax + 1): + T, Tinv = self.complex2real(ell) + TT.append(T) + TTI.append(Tinv) + + def permutek_block(Ak, k): + dk = 2 * k + 1 + Pk = np.eye(dk) + for m in range(k): + for l in range(k - m): + Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ + (m + 2 * l + 1, m + 2 * l), : + ] + AkP = Pk @ Ak @ Pk.T + return AkP[:k, :k], AkP[k:, k:] + + W0 = [] + W1 = [] + for k in range(start, self.Lmax + 1): + print(k) + W0k = np.zeros((Ngrid, k, k)) + W1k = np.zeros((Ngrid, k + 1, k + 1)) + + TkT = TT[k - start].T + TinvkT = TTI[k - start].T + + for p in range(Ngrid): + w = np.real(TkT @ self.WD(k, SO3_grid[p]).conj() @ TinvkT) + W0k[p], W1k[p] = permutek_block(w, k) + + W0.append(W0k) + W1.append(W1k) + return W0, W1 + ######################### # Euler Estimation Step # ######################### From eb0470c3b043408bd099875188ac663d7f414346 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 8 Apr 2026 15:48:19 -0400 Subject: [PATCH 07/38] vectorize Wigner matrix comps. Vectorize fejer_weights comp. --- src/aspire/abinitio/commonline_nug.py | 56 ++++++++++----------------- 1 file changed, 20 insertions(+), 36 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 41f48f35f2..ff0a590ae0 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -168,10 +168,7 @@ def fij(alpha, gamma, i, j, loss): BTK = [] for k in range(1, Lmax + 1): - dk = 2 * k + 1 - btk = np.zeros((dk, dk)) - for n in range(2 * T): - btk = btk + bT[n] * self.Wd(k, beta_grid[n]) + btk = np.sum(bT[:, None, None] * self.Wd(k, beta_grid), axis=0) BTK.append(btk.T) def fijhat_k(k, F): @@ -755,16 +752,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): bEq = xp.repeat(bEq[:, xp.newaxis], N * (N - 1) // 2, axis=1) # AI and bI - W0 = [ - loadmat("data/Fejer/Ngrid=%i" % Ngrid + "/k=%i" % k + ".mat")["W0"] - for k in range(1, Lmax + 1) - ] - W1 = [ - loadmat("data/Fejer/Ngrid=%i" % Ngrid + "/k=%i" % k + ".mat")["W1"] - for k in range(1, Lmax + 1) - ] - - W0k, W1k = self.compute_fejer_weights() + W0, W1 = self.compute_fejer_weights() # AI_mat=np.zeros((Ngrid,D0+D1)) # for p in range(Ngrid): @@ -888,7 +876,7 @@ def permutek_block(Ak, k): (m + 2 * l + 1, m + 2 * l), : ] AkP = Pk @ Ak @ Pk.T - return AkP[:k, :k], AkP[k:, k:] + return AkP[..., :k, :k], AkP[..., k:, k:] W0 = [] W1 = [] @@ -900,9 +888,8 @@ def permutek_block(Ak, k): TkT = TT[k - start].T TinvkT = TTI[k - start].T - for p in range(Ngrid): - w = np.real(TkT @ self.WD(k, SO3_grid[p]).conj() @ TinvkT) - W0k[p], W1k[p] = permutek_block(w, k) + w = np.real(TkT @ self.WD(k, SO3_grid).conj() @ TinvkT) + W0k, W1k = permutek_block(w, k) W0.append(W0k) W1.append(W1k) @@ -979,14 +966,15 @@ def find_gamma(Xm, beta, alpha): Jk[S + 1 :: 2] = -1 Jk[S - 1 :: -2] = -1 Jk = np.diag(Jk) + ws = self.Wd(S, beta) for i in range(N): + wi = ws[i] for j in range(i + 1, N): Di = np.exp(-1j * np.arange(-S, S + 1) * alpha[i]) Dj = np.exp(-1j * np.arange(-S, S + 1) * alpha[j]) Xijm = Xm[dk * i : dk * (i + 1), dk * j : dk * (j + 1)] DXijmD = np.diag(Di.conj()) @ Xijm @ np.diag(Dj) - wi = self.Wd(S, beta[i]) - wj = self.Wd(S, beta[j]) + wj = ws[j] C1 = ( wi[:, 0][:, None] @ (wj[:, 0][:, None].T) + Jk @ wi[:, 0][:, None] @ (wj[:, 0][:, None].T) @ Jk @@ -1125,10 +1113,7 @@ def compute_rank(self, sym, Lmax): rk = xp.zeros(Lmax) A = [] for k in range(1, Lmax + 1): - dk = 2 * k + 1 - Ak = np.zeros((dk, dk), dtype=np.complex128) - for s in range(S): - Ak += self.WD(k, sym_euler[s]) + Ak = np.sum(self.WD(k, sym_euler), axis=0) Ak = np.round(Ak / S, 6) A.append(Ak) rk[k - 1] = np.linalg.matrix_rank(Ak) @@ -1168,22 +1153,22 @@ def Symmetry_Euler(sym): def WD(self, J, euler): # compute Wigner D matrix - alpha = euler[0] - beta = euler[1] - gamma = euler[2] + alpha = euler[:, 0] + beta = euler[:, 1] + gamma = euler[:, 2] d = self.Wd(J, beta) - D = ( - np.diag(np.exp(-1j * alpha * np.arange(-J, J + 1))) - @ d - @ np.diag(np.exp(-1j * gamma * np.arange(-J, J + 1))) - ) + + m = np.arange(-J, J + 1) + left = np.exp(-1j * alpha[:, None] * m[None, :]) + right = np.exp(-1j * gamma[:, None] * m[None, :]) + D = left[:, :, None] * d * right[:, None, :] + return D @staticmethod def Wd(J, beta): # compute Wigner small d matrix - d = np.zeros((2 * J + 1, 2 * J + 1)) - + d = np.zeros((len(beta), 2 * J + 1, 2 * J + 1)) for m in range(-J, J + 1): for n in range(-J, J + 1): smin = max(0, m - n) @@ -1199,13 +1184,12 @@ def Wd(J, beta): * np.sqrt(factorial(J - n)) / factorial(J - n - s) ) - d[n + J, m + J] += ( + d[:, n + J, m + J] += ( mul * (-1) ** (n - m + s) * (np.cos(beta / 2)) ** (2 * J + m - n - 2 * s) * (np.sin(beta / 2)) ** (n - m + 2 * s) ) - return d @staticmethod From 1a8ecdeb99806d6cedb29ba792a45377e0fa7cde Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 16 Apr 2026 11:45:56 -0400 Subject: [PATCH 08/38] Add SO3 grid generation --- src/aspire/abinitio/commonline_nug.py | 38 ++++++++++++++++++++++++++- src/aspire/utils/__init__.py | 1 + 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index ff0a590ae0..01dc3fe995 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -10,6 +10,7 @@ from aspire.nufft import nufft from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed +from aspire.utils import cart2sph logger = logging.getLogger(__name__) @@ -785,6 +786,18 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): # this needs double checking AI_mat_offdiag[p, : d0[-1]] = w0 AI_mat_offdiag[p, d0[-1] :] = w1 + + # Vectorized version, added by Josh + # AI_mat_offdiag_new = np.zeros((Ngrid, D0 + D1)) + # for k in range(1, Lmax + 1): + # scale = (Lmax - k + 2) * (Lmax - k + 1) * (k + 0.5) + + # block0 = scale * W0[k - 1].transpose(0, 2, 1).reshape(Ngrid, -1) + # block1 = scale * W1[k - 1].transpose(0, 2, 1).reshape(Ngrid, -1) + + # AI_mat_offdiag_new[:, d0[k - 1]:d0[k]] = block0 + # AI_mat_offdiag_new[:, d0[-1] + d1[k - 1]:d0[-1] + d1[k]] = block1 + AI_mat_diag = np.zeros((Ngrid, D0 + D1)) for p in range(Ngrid): w0 = np.zeros(D0) @@ -856,7 +869,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): ) def compute_fejer_weights(self): - SO3_grid = loadmat("data/SO3_grid.mat")["SO3"] + SO3_grid = self.discretize_SO3() Ngrid = SO3_grid.shape[0] start = 1 @@ -895,6 +908,29 @@ def permutek_block(Ak, k): W1.append(W1k) return W0, W1 + def discretize_SO3(self): + S2 = loadmat("design20.mat")["design"] + S2_size = S2.shape[0] + + # discretize S1 + S1_size = round(np.sqrt(np.pi * S2_size)) + alpha = np.linspace(0, 2 * np.pi, S1_size) + + # discretize S2 + gamma, beta, _ = cart2sph(S2[:, 0], S2[:, 1], S2[:, 2]) + beta = np.pi / 2 - beta + gamma = gamma + np.pi + + # SO(3) in Euler ZYZ + SO3 = np.zeros((S2_size * S1_size, 3)) + count = 0 + for i in range(S1_size): + for j in range(S2_size): + SO3[count] = [alpha[i], beta[j], gamma[j]] + count += 1 + + return SO3 + ######################### # Euler Estimation Step # ######################### diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index ae896ebeb8..b7bd2d1947 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -2,6 +2,7 @@ from .coor_trans import ( # isort:skip mean_aligned_angular_distance, cart2pol, + cart2sph, crop_pad_2d, crop_pad_3d, grid_1d, From 84b7fba64abbea9f0d0ea3e72492f7b83b87602c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 17 Apr 2026 10:46:35 -0400 Subject: [PATCH 09/38] tox --- src/aspire/abinitio/commonline_nug.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 01dc3fe995..14d8cded0b 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -884,9 +884,9 @@ def permutek_block(Ak, k): dk = 2 * k + 1 Pk = np.eye(dk) for m in range(k): - for l in range(k - m): - Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[ - (m + 2 * l + 1, m + 2 * l), : + for ell in range(k - m): + Pk[(m + 2 * ell, m + 2 * ell + 1), :] = Pk[ + (m + 2 * ell + 1, m + 2 * ell), : ] AkP = Pk @ Ak @ Pk.T return AkP[..., :k, :k], AkP[..., k:, k:] @@ -894,7 +894,6 @@ def permutek_block(Ak, k): W0 = [] W1 = [] for k in range(start, self.Lmax + 1): - print(k) W0k = np.zeros((Ngrid, k, k)) W1k = np.zeros((Ngrid, k + 1, k + 1)) From 4255a33170d4108c139fcd84820a95383da683a3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 20 Apr 2026 11:14:38 -0400 Subject: [PATCH 10/38] use aspire symmetry parsing --- src/aspire/abinitio/commonline_nug.py | 59 +++++++-------------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 14d8cded0b..abedcd5431 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -11,6 +11,7 @@ from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed from aspire.utils import cart2sph +from aspire.volume import SymmetryGroup logger = logging.getLogger(__name__) @@ -70,7 +71,11 @@ def __init__( self.mult = mult self.Ngrid = Ngrid self.Nstep_yI = Nstep_yI - self.sym = symmetry + + # Handle symmetry + self.sym_grp = SymmetryGroup.parse(symmetry) + self.sym_euler = self.sym_grp.rotations.angles + self.n_sym = len(self.sym_euler) self._build_full_pft() @@ -79,11 +84,9 @@ def _build_full_pft(self): self.pf_full = PolarFT.half_to_full(pf) def estimate_rotations(self): - sym_euler, S = self.Symmetry_Euler(self.sym) imgs = self.src.images[:] C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) X_est = self.admm_sym_J( - self.sym, C, self.Lmax, self.n_img, @@ -96,7 +99,7 @@ def estimate_rotations(self): self.Nstep_yI, ) - R_est, Euler_est = self.euler_est(X_est[0], X_est[S - 1], self.sym, self.n_img) + R_est, Euler_est = self.euler_est(X_est[0], X_est[self.n_sym - 1]) self.rotations = R_est return R_est @@ -292,7 +295,6 @@ def complex2real(ell): def admm_sym_J( self, - sym, C, Lmax, N, @@ -333,8 +335,7 @@ def admm_sym_J( Sq, ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid) - # S=int(sym[1]); sym_euler=Symmetry_Euler(sym) - rank_Ak, _ = self.compute_rank(sym, Lmax) + rank_Ak, _ = self.compute_rank(Lmax) logger.info(f"Rank of Ak: {rank_Ak}") # rank_Ak=cp.zeros(Lmax) @@ -934,8 +935,9 @@ def discretize_SO3(self): # Euler Estimation Step # ######################### - def euler_est(self, X1, XS, sym, N): - S = int(sym[1]) + def euler_est(self, X1, XS): + S = self.n_sym + N = self.n_img sym_euler = np.zeros((S, 3)) for s in range(S): sym_euler[s] = [2 * np.pi * s / S, 0, 0] @@ -1143,49 +1145,16 @@ def largest_eigenvalue(AI, Ngrid, N): # Lambda=xp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) return Lambda - def compute_rank(self, sym, Lmax): - sym_euler, S = self.Symmetry_Euler(sym) + def compute_rank(self, Lmax): rk = xp.zeros(Lmax) A = [] for k in range(1, Lmax + 1): - Ak = np.sum(self.WD(k, sym_euler), axis=0) - Ak = np.round(Ak / S, 6) + Ak = np.sum(self.WD(k, self.sym_euler), axis=0) + Ak = np.round(Ak / self.n_sym, 6) A.append(Ak) rk[k - 1] = np.linalg.matrix_rank(Ak) return rk, A - @staticmethod - def Symmetry_Euler(sym): - if sym[0] == "C": - order = int(sym[1:]) - sym_euler = np.zeros((order, 3)) - for i in range(order): - sym_euler[i] = spr.from_euler( - "zyx", [2 * np.pi / order * i, 0, 0] - ).as_euler("zyz") - - if sym[0] == "D": - order = int(sym[1:]) - sym_euler = np.zeros((2 * order, 3)) - for i in range(order): - sym_euler[i] = spr.from_euler( - "zyx", [2 * np.pi / order * i, 0, 0] - ).as_euler("zyz") - sym_euler[i + order] = spr.from_euler( - "zyx", [2 * np.pi / order * i, 0, np.pi] - ).as_euler("zyz") - - if sym == "T12": - sym_euler = np.zeros((12, 3)) - sym_euler[1] = spr.from_rotvec( - 2 * np.pi / 3 * np.array([0, 0, 1]) - ).as_euler("zyz") - sym_euler[2] = spr.from_rotvec( - 4 * np.pi / 3 * np.array([0, 0, 1]) - ).as_euler("zyz") - - return sym_euler, sym_euler.shape[0] - def WD(self, J, euler): # compute Wigner D matrix alpha = euler[:, 0] From b7a15734c6b4f86c07b83dc46e8657142e3b55bb Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 20 Apr 2026 13:31:19 -0400 Subject: [PATCH 11/38] Use aspire ZYZ rotation convention --- src/aspire/abinitio/commonline_nug.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index abedcd5431..2f9d8a9c19 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -10,7 +10,7 @@ from aspire.nufft import nufft from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed -from aspire.utils import cart2sph +from aspire.utils import Rotation, cart2sph from aspire.volume import SymmetryGroup logger = logging.getLogger(__name__) @@ -1041,11 +1041,7 @@ def find_gamma(Xm, beta, alpha): Euler_est[:, 0] = find_alpha(X1) Euler_est[:, 1] = find_beta(X1) Euler_est[:, 2] = find_gamma(XS, Euler_est[:, 1], Euler_est[:, 0]) - R_est = np.zeros((N, 3, 3)) - for n in range(N): - R_est[n] = spr.from_euler("zyz", np.flip(Euler_est[n])).as_matrix().T - # note that the order of alpha and gamma are swapped due to convention - + R_est = Rotation.from_euler(Euler_est).matrices.transpose(0, 2, 1) return R_est, Euler_est #################### From 3ae7c2a5dceb449d9f295254d2e0bc70818bd431 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 21 Apr 2026 16:06:49 -0400 Subject: [PATCH 12/38] vectorize psd_projection, transform_block, transform_back_block --- src/aspire/abinitio/commonline_nug.py | 42 +++++++++++++++++++-------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 2f9d8a9c19..cf8647d064 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -3,7 +3,6 @@ import numpy as np from scipy.io import loadmat -from scipy.spatial.transform import Rotation as spr from scipy.special import factorial from aspire.abinitio import CLOrient3D @@ -335,6 +334,7 @@ def admm_sym_J( Sq, ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid) + n_pairs = N * (N - 1) // 2 rank_Ak, _ = self.compute_rank(Lmax) logger.info(f"Rank of Ak: {rank_Ak}") @@ -473,10 +473,10 @@ def update_S(C0, C1, yE, yEq, yI, X0, X1, Xd0, Xd1, Xq, rho, Lmax, N): ) toc1 = time.perf_counter() Time[1] += toc1 - tic1 + tic2 = time.perf_counter() - for count in range(N * (N - 1) // 2): - tmp = self.psd_projection(Sq[:, count].reshape(4, 4).T) - Sq[:, count] = tmp.T.reshape(16) + Sq = self.psd_projection(Sq.T.reshape(n_pairs, 4, 4)) + Sq = Sq.T.reshape(-1, n_pairs) toc2 = time.perf_counter() Time[2] += toc2 - tic2 return S0, S1, Sd0, Sd1, Sq @@ -822,7 +822,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): AI_mat_diag = xp.asarray(AI_mat_diag) / 1 AI_mat_offdiag = xp.asarray(AI_mat_offdiag) / 1 bI = -(Lmax + 2) * (Lmax + 1) / 2 / 1 - np.save("AI_mat_offdiag_aspire.npy", AI_mat_offdiag) + # largest eigenvalue for AIAIT Lambda = self.largest_eigenvalue(AI_mat_offdiag, Ngrid, N) @@ -1203,12 +1203,17 @@ def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): @staticmethod def psd_projection(B): # compute the PSD part of a symmstric matrix - evals, evecs = xp.linalg.eigh((B + B.T) / 2) + B_sym = (B + B.swapaxes(-1, -2)) / 2 + evals, evecs = xp.linalg.eigh(B_sym) evals = xp.maximum(evals, 0) - return (evecs * evals) @ evecs.T + return (evecs * evals[..., None, :]) @ evecs.swapaxes(-1, -2) @staticmethod def transform_block(A, k, Pk=None): + single = A.ndim == 2 + if single: + A = A[None, :, :] + if Pk is None: dk = 2 * k + 1 Pk = xp.eye(dk) @@ -1218,14 +1223,26 @@ def transform_block(A, k, Pk=None): (m + 2 * el + 1, m + 2 * el), : ] AT = Pk @ A @ Pk.T - return AT[:k, :k].T.reshape(-1), AT[k:, k:].T.reshape(-1) + A0 = AT[:, :k, :k].swapaxes(-1, -2).reshape(A.shape[0], -1) + A1 = AT[:, k:, k:].swapaxes(-1, -2).reshape(A.shape[0], -1) + + if single: + return A0[0], A1[0] + + return A0, A1 @staticmethod def transform_back_block(A0, A1, k, Pk=None): dk = 2 * k + 1 - A = xp.zeros((dk, dk)) - A[:k, :k] = A0.reshape(k, k).T - A[k:, k:] = A1.reshape(k + 1, k + 1).T + single = A0.ndim == 1 + + if single: + A0 = A0[None, :] + A1 = A1[None, :] + + A = xp.zeros((A0.shape[0], dk, dk)) + A[:, :k, :k] = A0.reshape(-1, k, k).swapaxes(-1, -2) + A[:, k:, k:] = A1.reshape(-1, k + 1, k + 1).swapaxes(-1, -2) if Pk is None: Pk = xp.eye(dk) for m in range(k): @@ -1233,4 +1250,5 @@ def transform_back_block(A0, A1, k, Pk=None): Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ (m + 2 * el + 1, m + 2 * el), : ] - return Pk.T @ A @ Pk + out = Pk.T @ A @ Pk + return out[0] if single else out From 3d0b0c28fe5b66399f046809c8a6aee84244fce1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 22 Apr 2026 11:12:27 -0400 Subject: [PATCH 13/38] Vectorize update_S. 2x speedup! --- src/aspire/abinitio/commonline_nug.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index cf8647d064..27f7f48912 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -458,19 +458,24 @@ def update_S(C0, C1, yE, yEq, yI, X0, X1, Xd0, Xd1, Xq, rho, Lmax, N): S1[d1[k - 1] : d1[k], :] = self.vec_block(tmp, N, k + 1, IDX_upper) toc0 = time.perf_counter() Time[0] += toc0 - tic0 + + # See if we can switch everything to C order before here. + Sd0 = Sd0.T + Sd1 = Sd1.T tic1 = time.perf_counter() for k in range(1, Lmax + 1): - for n in range(N): - tmp = self.transform_back_block( - Sd0[d0[k - 1] : d0[k], n], - Sd1[d1[k - 1] : d1[k], n], - k, - P[k - 1], - ) - tmp = self.psd_projection(tmp) - Sd0[d0[k - 1] : d0[k], n], Sd1[d1[k - 1] : d1[k], n] = ( - self.transform_block(tmp, k, P[k - 1]) - ) + tmp = self.transform_back_block( + Sd0[:, d0[k - 1] : d0[k]], + Sd1[:, d1[k - 1] : d1[k]], + k, + P[k - 1], + ) + tmp = self.psd_projection(tmp) + Sd0[:, d0[k - 1] : d0[k]], Sd1[:, d1[k - 1] : d1[k]] = ( + self.transform_block(tmp, k, P[k - 1]) + ) + Sd0 = Sd0.T + Sd1 = Sd1.T toc1 = time.perf_counter() Time[1] += toc1 - tic1 From 61affef8d9af0d20b0575a5bdb7be6fa9f6d0902 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 22 Apr 2026 12:02:32 -0400 Subject: [PATCH 14/38] vectorize print update component. --- src/aspire/abinitio/commonline_nug.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 27f7f48912..45df6798b1 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -611,15 +611,16 @@ def print_updates(verbose=True): res_psdX = res_psdX / (1 + np.linalg.norm(X0) + np.linalg.norm(X1)) res_psdD = 0 for k in range(1, Lmax + 1): - for n in range(N): - tmp = self.transform_back_block( - Xd0[d0[k - 1] : d0[k], n], - Xd1[d1[k - 1] : d1[k], n], - k, - P[k - 1], - ) - res_psdD += np.linalg.norm(self.psd_projection(-tmp)) - # res_psdD+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) + tmp = self.transform_back_block( + Xd0[d0[k - 1] : d0[k]].T, + Xd1[d1[k - 1] : d1[k]].T, + k, + P[k - 1], + ) + res_psdD += np.linalg.norm( + self.psd_projection(-tmp), axis=(-2, -1) + ).sum() + # res_psdD+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) res_psdD = res_psdD / (1 + np.linalg.norm(Xd0) + np.linalg.norm(Xd1)) res_psdQ = 0 for count in range(N * (N - 1) // 2): From 20186fcbcd28eaa191dd2a4910118c598c2214b7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 24 Apr 2026 15:36:19 -0400 Subject: [PATCH 15/38] intial add of proximal_refine --- src/aspire/abinitio/commonline_nug.py | 97 +++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 45df6798b1..ccd0f33226 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -39,6 +39,7 @@ def __init__( mult=1.5, Ngrid=16317, Nstep_yI=10, + perform_pr=False, **kwargs, ): """ @@ -70,6 +71,7 @@ def __init__( self.mult = mult self.Ngrid = Ngrid self.Nstep_yI = Nstep_yI + self.perform_pr = perform_pr # Handle symmetry self.sym_grp = SymmetryGroup.parse(symmetry) @@ -98,6 +100,26 @@ def estimate_rotations(self): self.Nstep_yI, ) + if self.perform_pr: + weight = 1 / (1 + np.arange(self.Lmax)) + Penalty = [1, 1, 1, 1] + r = [3, 2, 1, 0] + X_est = self.proximal_refine( + X_est, + C, + self.n_img, + weight, + Penalty, + r, + self.Ngrid, + self.max_iter, + self.rho, + self.ratio, + self.factor, + self.mult, + self.Nstep_yI, + ) + R_est, Euler_est = self.euler_est(X_est[0], X_est[self.n_sym - 1]) self.rotations = R_est @@ -937,6 +959,81 @@ def discretize_SO3(self): return SO3 + ############################ + # Proximal Refinement Step # + ############################ + + def proximal_refine( + self, + X_admm, + C, + N, + weight, + Penalty, + r, + Ngrid, + max_iter, + rho, + ratio, + factor, + mult, + Nstep_yI, + verbose=True, + ): + + def Ak(J, Euler): + # compute Ak matrix + order = Euler.shape[0] + # A = np.zeros((2 * J + 1, 2 * J + 1), dtype=np.complex128) + # for i in range(order): + # A += self.WD(J, Euler[i]) + A = self.WD(J, Euler).sum(axis=0) + return np.round(A / order, 10) + + rank_Ak = xp.zeros(self.Lmax) + for k in range(self.Lmax): + C[k] = xp.asnumpy(C[k]) + rank_Ak[k] = np.linalg.matrix_rank( + Ak(k + 1, self.sym_euler), tol=1e-6, hermitian=True + ) + + def low_rank_proj(X, r): + Xproj = [] + for k in range(self.Lmax): + dk = 2 * k + 3 + rk = int(rank_Ak[k] * 2) + r + tmp = np.copy(X[k]) + for i in range(N): + u, s, v = np.linalg.svd( + tmp[i * dk : (i + 1) * dk, i * dk : (i + 1) * dk] + ) + tmp[i * dk : (i + 1) * dk, i * dk : (i + 1) * dk] = ( + u[:, :rk] @ np.diag(s[:rk]) @ v[:rk] + ) + Xproj.append(tmp) + return Xproj + + Niter = len(r) + CC = [None] * self.Lmax + for iter in range(Niter): + X_prox = low_rank_proj(X_admm, r[iter]) + for k in range(self.Lmax): + CC[k] = C[k] - Penalty[iter] * weight[k] * (X_prox[k] + X_prox[k].T) / 2 + X_prox = self.admm_sym_J( + CC, + self.Lmax, + N, + Ngrid, + max_iter, + rho, + ratio, + factor, + mult, + Nstep_yI, + verbose, + ) + return X_prox + ######################### # Euler Estimation Step # ######################### From 43b2d44c1e3bf293e10949ff13c309c390c88d70 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 30 Apr 2026 14:31:03 -0400 Subject: [PATCH 16/38] Update base class. Fix Proximal Refinement bug. Add PR update diagnostic log. --- src/aspire/abinitio/commonline_nug.py | 55 +++++++++++++++++++-------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index ccd0f33226..3998dfc90a 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -5,7 +5,7 @@ from scipy.io import loadmat from scipy.special import factorial -from aspire.abinitio import CLOrient3D +from aspire.abinitio import Orient3D from aspire.nufft import nufft from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -class CommonlineNUG(CLOrient3D): +class CommonlineNUG(Orient3D): """ Class to estimate 3D orientations using non-uqique games. """ @@ -980,28 +980,33 @@ def proximal_refine( Nstep_yI, verbose=True, ): - def Ak(J, Euler): # compute Ak matrix order = Euler.shape[0] - # A = np.zeros((2 * J + 1, 2 * J + 1), dtype=np.complex128) - # for i in range(order): - # A += self.WD(J, Euler[i]) A = self.WD(J, Euler).sum(axis=0) return np.round(A / order, 10) - rank_Ak = xp.zeros(self.Lmax) + def rel_change(A, B, eps=1e-12): + num = 0.0 + den = 0.0 + for k in range(len(A)): + num += np.linalg.norm(A[k] - B[k]) ** 2 + den += np.linalg.norm(B[k]) ** 2 + return np.sqrt(num) / max(np.sqrt(den), eps) + + rank_Ak = np.zeros(self.Lmax) + C_base = [None] * self.Lmax for k in range(self.Lmax): - C[k] = xp.asnumpy(C[k]) + C_base[k] = xp.asnumpy(C[k]).copy() rank_Ak[k] = np.linalg.matrix_rank( Ak(k + 1, self.sym_euler), tol=1e-6, hermitian=True ) - def low_rank_proj(X, r): + def low_rank_proj(X, r_step): Xproj = [] for k in range(self.Lmax): dk = 2 * k + 3 - rk = int(rank_Ak[k] * 2) + r + rk = min(int(rank_Ak[k] * 2) + r_step, dk) tmp = np.copy(X[k]) for i in range(N): u, s, v = np.linalg.svd( @@ -1015,11 +1020,18 @@ def low_rank_proj(X, r): Niter = len(r) CC = [None] * self.Lmax - for iter in range(Niter): - X_prox = low_rank_proj(X_admm, r[iter]) + current = [np.copy(Xk) for Xk in X_admm] + + for step in range(Niter): + X_proj = low_rank_proj(current, r[step]) + for k in range(self.Lmax): - CC[k] = C[k] - Penalty[iter] * weight[k] * (X_prox[k] + X_prox[k].T) / 2 - X_prox = self.admm_sym_J( + CC[k] = ( + C_base[k] + - Penalty[step] * weight[k] * (X_proj[k] + X_proj[k].T) / 2 + ) + + X_next = self.admm_sym_J( CC, self.Lmax, N, @@ -1030,9 +1042,20 @@ def low_rank_proj(X, r): factor, mult, Nstep_yI, - verbose, + verbose=False, ) - return X_prox + + if verbose: + logger.info( + "Proximal refine step %d/%d: relative update %.3e", + step + 1, + Niter, + rel_change(X_next, current), + ) + + current = [np.copy(Xk) for Xk in X_next] + + return current ######################### # Euler Estimation Step # From 83fc0f4ddaeeae8b9559be616ef9b4c29896eb8a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 5 May 2026 14:23:49 -0400 Subject: [PATCH 17/38] Add euler_est_Dm --- src/aspire/abinitio/commonline_nug.py | 151 +++++++++++++++++++++++++- 1 file changed, 149 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 3998dfc90a..5635aa6edf 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -10,7 +10,7 @@ from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed from aspire.utils import Rotation, cart2sph -from aspire.volume import SymmetryGroup +from aspire.volume import CnSymmetryGroup, DnSymmetryGroup, SymmetryGroup logger = logging.getLogger(__name__) @@ -120,7 +120,11 @@ def estimate_rotations(self): self.Nstep_yI, ) - R_est, Euler_est = self.euler_est(X_est[0], X_est[self.n_sym - 1]) + if isinstance(self.sym_grp, CnSymmetryGroup): + R_est, Euler_est = self.euler_est(X_est[0], X_est[self.n_sym - 1]) + elif isinstance(self.sym_grp, DnSymmetryGroup): + R_est, Euler_est = self.euler_est_Dm(X_est) + self.rotations = R_est return R_est @@ -1170,6 +1174,149 @@ def find_gamma(Xm, beta, alpha): R_est = Rotation.from_euler(Euler_est).matrices.transpose(0, 2, 1) return R_est, Euler_est + def euler_est_Dm(self, X_est): + + X2 = X_est[1] + S = self.sym_grp.order + N = self.n_img + XS = X_est[S - 1] + dk = 2 * S + 1 + + def find_alpha_beta(X2): + def find_phase(A, B): + # find a number c that minimizes ||cA-B||_F + Ar = np.real(A) + Ai = np.imag(A) + Br = np.real(B) + Bi = np.imag(B) + c = (np.vdot(Ar, Br) + np.vdot(Ai, Bi)) / ( + np.vdot(Ar, Ar) + np.vdot(Ai, Ai) + ) + 1j * (np.vdot(Ar, Bi) - np.vdot(Ai, Br)) / ( + np.vdot(Ar, Ar) + np.vdot(Ai, Ai) + ) + return c / abs(c) + + T, Tinv = self.complex2real(2) + X2 = np.kron(np.eye(N), T) @ X2 @ np.kron(np.eye(N), Tinv) + + B1 = np.zeros((N, N)) + for i in range(N): + for j in range(N): + Xij = X2[5 * i : 5 * (i + 1), 5 * j : 5 * (j + 1)] + B1[i, j] = np.real((Xij[2, 2] - 2 * abs(Xij[0, 0]) + 0.5) / 3 * 2) + e1, v1 = np.linalg.eigh(B1) + idx = np.argmax(e1) + b1 = v1[:, idx] * np.sqrt(e1[idx]) * np.sign(v1[0, idx]) + beta_est = np.arccos(np.clip(np.sqrt(b1), -1, 1)) % np.pi + + Aminus = np.zeros((N, N), dtype=complex) + Aplus = np.zeros((N, N), dtype=complex) + for i in range(N): + for j in range(N): + Xij = X2[5 * i : 5 * (i + 1), 5 * j : 5 * (j + 1)] + if abs(beta_est[i]) < 1e-6: + beta_est[i] = 1e-6 + if abs(beta_est[j]) < 1e-6: + beta_est[j] = 1e-6 + Aminus[i, j] = ( + Xij[1, 1] + / np.sin(2 * beta_est[i]) + / np.sin(2 * beta_est[j]) + * 8 + / 3 + ) + Aplus[i, j] = ( + -Xij[1, 3] + / np.sin(2 * beta_est[i]) + / np.sin(2 * beta_est[j]) + * 8 + / 3 + ) + + evals, evecs = np.linalg.eigh(Aminus) + idx = np.argmax(abs(evals)) + Z = evecs[:, idx] * np.sqrt(abs(evals[idx])) + c = find_phase(Z[:, None] @ Z[:, None].T, Aplus) + Z = np.sqrt(c) * Z + alpha_est = (np.angle(Z)) % (2 * np.pi) + + return alpha_est, beta_est + + def find_gamma(Xm, alpha, beta): + def LS_D(W1, W2, W3, W4, Br, Bi): + A = np.array( + [ + [np.vdot(W1 + W4, W1 + W4), np.vdot(W1 + W4, W2 + W3)], + [np.vdot(W2 + W3, W1 + W4), np.vdot(W2 + W3, W2 + W3)], + ] + ) + B = np.array([np.vdot(W1 + W4, Br), np.vdot(W2 + W3, Br)]) + a, c = np.linalg.lstsq(A, B)[0] + + A = np.array( + [ + [np.vdot(W1 - W4, W1 - W4), np.vdot(W1 - W4, W3 - W2)], + [np.vdot(W1 - W4, W3 - W2), np.vdot(W3 - W2, W3 - W2)], + ] + ) + B = np.array([np.vdot(W1 - W4, Bi), np.vdot(W3 - W2, Bi)]) + b, d = np.linalg.lstsq(A, B)[0] + return a + 1j * b + + [T, Tinv] = self.complex2real(S) + Xm = np.kron(np.eye(N), T) @ Xm @ np.kron(np.eye(N), Tinv) + C = np.zeros((N, N), dtype=complex) + Jk = np.ones(dk) + Jk[S + 1 :: 2] = -1 + Jk[S - 1 :: -2] = -1 + Jk = np.diag(Jk) + ws = self.Wd(S, beta) + for i in range(N): + wi = ws[i] + for j in range(i + 1, N): + Di = np.exp(-1j * np.arange(-S, S + 1) * alpha[i]) + Dj = np.exp(-1j * np.arange(-S, S + 1) * alpha[j]) + Xijm = Xm[dk * i : dk * (i + 1), dk * j : dk * (j + 1)] + DXijmD = np.diag(Di.conj()) @ Xijm @ np.diag(Dj) + wj = ws[j] + W1 = ( + wi[:, 0][:, np.newaxis] @ wj[:, 0][:, np.newaxis].T + + Jk @ wi[:, 0][:, np.newaxis] @ wj[:, 0][:, np.newaxis].T @ Jk + ) + W2 = ( + wi[:, -1][:, np.newaxis] @ wj[:, 0][:, np.newaxis].T + + Jk @ wi[:, -1][:, np.newaxis] @ wj[:, 0][:, np.newaxis].T @ Jk + ) + W3 = ( + wi[:, 0][:, np.newaxis] @ wj[:, -1][:, np.newaxis].T + + Jk @ wi[:, 0][:, np.newaxis] @ wj[:, -1][:, np.newaxis].T @ Jk + ) + W4 = ( + wi[:, -1][:, np.newaxis] @ wj[:, -1][:, np.newaxis].T + + Jk + @ wi[:, -1][:, np.newaxis] + @ wj[:, -1][:, np.newaxis].T + @ Jk + ) + Br = np.real(4 * DXijmD) + Bi = np.imag(4 * DXijmD) + C[i, j] = LS_D(W1, W2, W3, W4, Br, Bi) + C += C.T.conj() + np.eye(N) + evals, evecs = np.linalg.eigh(C) + idx = np.argmax(evals) + c = evecs[:, idx] * np.sqrt(evals[idx]) + return (np.angle(c) / S) % (2 * np.pi) + + alpha_est, beta_est = find_alpha_beta(X2) + gamma_est = find_gamma(XS, alpha_est, beta_est) + Euler_est = np.zeros((N, 3)) + Euler_est[:, 0] = alpha_est + Euler_est[:, 1] = beta_est + Euler_est[:, 2] = gamma_est + R_est = Rotation.from_euler(Euler_est).matrices.transpose(0, 2, 1) + + return R_est, Euler_est + #################### # Helper Functions # #################### From 6cb35c5176e29f6de629201e74f33ffa406fe0e7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 13 May 2026 15:32:07 -0400 Subject: [PATCH 18/38] Use in-house polarFT and Sinogram. Add symmetry handling and logs --- src/aspire/abinitio/commonline_nug.py | 48 +++++++++++++++++---------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 5635aa6edf..ece406f0bb 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -74,7 +74,18 @@ def __init__( self.perform_pr = perform_pr # Handle symmetry - self.sym_grp = SymmetryGroup.parse(symmetry) + if symmetry is None: + logger.info( + f"Symmetry not provided. Using Source symmetry: {str(self.src.symmetry_group)}" + ) + self.sym_grp = self.src.symmetry_group + else: + if symmetry != str(self.src.symmetry_group): + logger.info( + f"Provided symmetry, {symmetry}, does not match source, {str(self.src.symmetry_group)}" + ) + logger.info(f"Using provided symmetry: {symmetry}") + self.sym_grp = SymmetryGroup.parse(symmetry) self.sym_euler = self.sym_grp.rotations.angles self.n_sym = len(self.sym_euler) @@ -142,11 +153,12 @@ def compute_coeff(self, Img, loss, Lmax, T): Img_pft = np.zeros((L, n_theta, N), dtype=complex) # Replace with Image.project() later - Img = Img.asnumpy() - for n in range(N): - line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( - Img[n], angular_sampling - ) + line_proj = Img.project(angular_sampling).asnumpy().T + # Img = Img.asnumpy() + # for n in range(N): + # line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( + # Img[n], angular_sampling + # ) dim_wave = len(wemd_embed(line_proj[:, 0, 0])) WE = np.zeros((dim_wave, n_theta, N)) @@ -156,22 +168,22 @@ def compute_coeff(self, Img, loss, Lmax, T): def fij(alpha, gamma, i, j, loss): if loss == "l1": - Ii_hat = Img_pft[:, :, i] - Ij_hat = Img_pft[:, :, j] - idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + # Ii_hat = Img_pft[:, :, i] + # Ij_hat = Img_pft[:, :, j] + # idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + # idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - Si = Ii_hat[:, int(idxi)] - Sj = Ij_hat[:, int(idxj)] + # Si = Ii_hat[:, int(idxi)] + # Sj = Ij_hat[:, int(idxj)] # Using aspire PolarFT. Replace later - # Ii_hat = self.pf_full[i] - # Ij_hat = self.pf_full[j] - # idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - # idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + Ii_hat = self.pf_full[i] + Ij_hat = self.pf_full[j] + idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - # Si = Ii_hat[int(idxi)] - # Sj = Ij_hat[int(idxj)] + Si = Ii_hat[int(idxi)] + Sj = Ij_hat[int(idxj)] # norm_new = np.linalg.norm(Si - Sj, 1) return np.linalg.norm(Si - Sj, 1) From 0295cb35e199d8e381690cc68fda9d17f50c9fa4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 13 May 2026 15:46:44 -0400 Subject: [PATCH 19/38] cleanup --- src/aspire/abinitio/commonline_nug.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index ece406f0bb..23624aae1f 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -149,11 +149,12 @@ def compute_coeff(self, Img, loss, Lmax, T): N, L, _ = Img.shape n_theta = 360 angular_sampling = np.arange(0, 360, 1) - line_proj = np.zeros((L, n_theta, N)) - Img_pft = np.zeros((L, n_theta, N), dtype=complex) - # Replace with Image.project() later + # Using ASPIRE Image.project(). Leaving original method in comments for now. line_proj = Img.project(angular_sampling).asnumpy().T + + # line_proj = np.zeros((L, n_theta, N)) + # Img_pft = np.zeros((L, n_theta, N), dtype=complex) # Img = Img.asnumpy() # for n in range(N): # line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( From 9eb4ced7f910524feaf64597220be15b21fab1eb Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 14 May 2026 14:26:39 -0400 Subject: [PATCH 20/38] Use saff-kuijlaars method for S2 grid --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_d2.py | 35 ++----------------------- src/aspire/abinitio/commonline_nug.py | 26 +++++++++--------- src/aspire/abinitio/commonline_utils.py | 31 ++++++++++++++++++++++ 4 files changed, 46 insertions(+), 47 deletions(-) diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 4b6248c4c7..cb68c8f544 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -3,6 +3,7 @@ from .commonline_utils import ( build_outer_products, g_sync, + saff_kuijlaars, ) from .commonline_base import Orient3D from .commonline_matrix import CLOrient3D diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py index 34f28ee775..13ec282346 100644 --- a/src/aspire/abinitio/commonline_d2.py +++ b/src/aspire/abinitio/commonline_d2.py @@ -10,7 +10,7 @@ from aspire.utils.random import randn from aspire.volume import DnSymmetryGroup -from .commonline_utils import _generate_shift_phase_and_filter +from .commonline_utils import _generate_shift_phase_and_filter, saff_kuijlaars logger = logging.getLogger(__name__) @@ -160,7 +160,7 @@ def _generate_lookup_data(self): logger.info("Generating commonline lookup data.") # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter # of sphere because of D2 symmetry redundancy. - sphere_grid = self._saff_kuijlaars(self.grid_res) + sphere_grid = saff_kuijlaars(self.grid_res) octant1_mask = np.all(sphere_grid > 0, axis=1) octant2_mask = ( (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) @@ -1822,37 +1822,6 @@ def _circ_seq(n1, n2, L): return seq - @staticmethod - def _saff_kuijlaars(N): - """ - Generates N vertices on the unit sphere that are approximately evenly distributed. - - This implements the recommended algorithm in spherical coordinates - (theta, phi) according to "Distributing many points on a sphere" - by E.B. Saff and A.B.J. Kuijlaars, Mathematical Intelligencer 19.1 - (1997) 5--11. - - :param N: Number of vertices to generate. - - :return: Nx3 array of vertices in cartesian coordinates. - """ - k = np.arange(1, N + 1) - h = -1 + 2 * (k - 1) / (N - 1) - theta = np.arccos(h) - phi = np.zeros(N) - - for i in range(1, N - 1): - phi[i] = (phi[i - 1] + 3.6 / (np.sqrt(N * (1 - h[i] ** 2)))) % (2 * np.pi) - - # Spherical coordinates - x = np.sin(theta) * np.cos(phi) - y = np.sin(theta) * np.sin(phi) - z = np.cos(theta) - - mesh = np.column_stack((x, y, z)) - - return mesh - @staticmethod def _mark_equators(sphere_grid, eq_filter_angle): """ diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 23624aae1f..70cee9ccb0 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -12,6 +12,8 @@ from aspire.utils import Rotation, cart2sph from aspire.volume import CnSymmetryGroup, DnSymmetryGroup, SymmetryGroup +from .commonline_utils import saff_kuijlaars + logger = logging.getLogger(__name__) @@ -37,7 +39,7 @@ def __init__( ratio=1, factor=1.0, mult=1.5, - Ngrid=16317, + S2_grid=441, Nstep_yI=10, perform_pr=False, **kwargs, @@ -69,7 +71,7 @@ def __init__( self.ratio = ratio self.factor = factor self.mult = mult - self.Ngrid = Ngrid + self.S2_grid = S2_grid self.Nstep_yI = Nstep_yI self.perform_pr = perform_pr @@ -102,7 +104,6 @@ def estimate_rotations(self): C, self.Lmax, self.n_img, - self.Ngrid, self.max_iter, self.rho, self.ratio, @@ -122,7 +123,6 @@ def estimate_rotations(self): weight, Penalty, r, - self.Ngrid, self.max_iter, self.rho, self.ratio, @@ -336,7 +336,6 @@ def admm_sym_J( C, Lmax, N, - Ngrid, max_iter, rho, ratio, @@ -371,9 +370,10 @@ def admm_sym_J( S0, S1, Sq, - ) = self.ADMM_preprocessing(C, Lmax, N, Ngrid) + ) = self.ADMM_preprocessing(C, Lmax, N) n_pairs = N * (N - 1) // 2 + Ngrid = self.Ngrid rank_Ak, _ = self.compute_rank(Lmax) logger.info(f"Rank of Ak: {rank_Ak}") @@ -741,7 +741,7 @@ def print_updates(verbose=True): X_admm[k] = xp.asnumpy(X_admm[k]) return X_admm - def ADMM_preprocessing(self, C, Lmax, N, Ngrid): + def ADMM_preprocessing(self, C, Lmax, N): # compute necessary quantities for ADMM # compute some useful index sets count = 0 @@ -799,8 +799,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): bEq = xp.repeat(bEq[:, xp.newaxis], N * (N - 1) // 2, axis=1) # AI and bI - W0, W1 = self.compute_fejer_weights() - + W0, W1, Ngrid = self.compute_fejer_weights() # AI_mat=np.zeros((Ngrid,D0+D1)) # for p in range(Ngrid): # w0=np.zeros(D0); w1=np.zeros(D1) @@ -811,7 +810,6 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): # AI_mat[p,:d0[-1]]=w0; AI_mat[p,d0[-1]:]=w1 # AI_mat=xp.asarray(AI_mat) / 10; # bI=-(Lmax+2)*(Lmax+1)/2 / 10 - AI_mat_offdiag = np.zeros((Ngrid, D0 + D1)) for p in range(Ngrid): w0 = np.zeros(D0) @@ -886,6 +884,7 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid): Sq = xp.zeros(Xq.shape) # S0=xp.zeros(X0.shape); S1=xp.zeros(X1.shape); Sq=xp.zeros(Xq.shape) + self.Ngrid = Ngrid # return C0,C1,normC,AEq,bEq,AEqAEqtinv,AI_mat,bI,Lambda,d0,d1,D0,D1,idx_diag,idx_offdiag,IDX_upper,IDX_lower,X0,X1,Xq,S0,S1,Sq return ( C0, @@ -951,10 +950,11 @@ def permutek_block(Ak, k): W0.append(W0k) W1.append(W1k) - return W0, W1 + return W0, W1, Ngrid def discretize_SO3(self): - S2 = loadmat("design20.mat")["design"] + # S2 = loadmat("design20.mat")["design"] + S2 = saff_kuijlaars(self.S2_grid) S2_size = S2.shape[0] # discretize S1 @@ -988,7 +988,6 @@ def proximal_refine( weight, Penalty, r, - Ngrid, max_iter, rho, ratio, @@ -1052,7 +1051,6 @@ def low_rank_proj(X, r_step): CC, self.Lmax, N, - Ngrid, max_iter, rho, ratio, diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index 46a3a932c0..5fe88a21b1 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -288,6 +288,37 @@ def _cl_angles_to_ind(cl_angles, n_theta): return ind +def saff_kuijlaars(N): + """ + Generates N vertices on the unit sphere that are approximately evenly distributed. + + This implements the recommended algorithm in spherical coordinates + (theta, phi) according to "Distributing many points on a sphere" + by E.B. Saff and A.B.J. Kuijlaars, Mathematical Intelligencer 19.1 + (1997) 5--11. + + :param N: Number of vertices to generate. + + :return: Nx3 array of vertices in cartesian coordinates. + """ + k = np.arange(1, N + 1) + h = -1 + 2 * (k - 1) / (N - 1) + theta = np.arccos(h) + phi = np.zeros(N) + + for i in range(1, N - 1): + phi[i] = (phi[i - 1] + 3.6 / (np.sqrt(N * (1 - h[i] ** 2)))) % (2 * np.pi) + + # Spherical coordinates + x = np.sin(theta) * np.cos(phi) + y = np.sin(theta) * np.sin(phi) + z = np.cos(theta) + + mesh = np.column_stack((x, y, z)) + + return mesh + + def g_sync(rots, order, rots_gt): """ Given ground truth rotations, synchronize estimated rotations over From fdc6dc5afe7b96e67b0fcdf053914acb287103cd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 May 2026 10:41:14 -0400 Subject: [PATCH 21/38] initial test file --- src/aspire/abinitio/commonline_nug.py | 2 +- tests/test_nug.py | 97 +++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 tests/test_nug.py diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 70cee9ccb0..c3bae03b70 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -28,7 +28,7 @@ def __init__( symmetry=None, n_rad=None, n_theta=360, - max_shift=0.15, + max_shift=0, shift_step=1, mask=True, Lmax=12, diff --git a/tests/test_nug.py b/tests/test_nug.py new file mode 100644 index 0000000000..4affb492e7 --- /dev/null +++ b/tests/test_nug.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest + +from aspire.abinitio import CommonlineNUG +from aspire.source import Simulation +from aspire.utils import ( + J_conjugate, + Random, + Rotation, + all_pairs, + mean_aligned_angular_distance, + utest_tolerance, +) +from aspire.volume import CnSymmetricVolume, CnSymmetryGroup + +DTYPE = [np.float32] +RESOLUTION = [48, 49] +N_IMG = [5] +OFFSETS = [0] +ORDER = [3, 4] +PR = [False] +SEED = 1980 + + +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") +def resolution(request): + return request.param + + +@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}", scope="module") +def n_img(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") +def offsets(request): + return request.param + + +@pytest.fixture(params=ORDER, ids=lambda x: f"order={x}", scope="module") +def order(request): + return request.param + + +@pytest.fixture(params=PR, ids=lambda x: f"proximal_refine={x}", scope="module") +def proximal_refine(request): + return request.param + + +############ +# Fixtures # +############ + +@pytest.fixture(scope="module") +def source(n_img, resolution, dtype, offsets, order): + vol = CnSymmetricVolume( + L=resolution, order=order, C=1, K=100, dtype=dtype, seed=SEED + ).generate() + + src = Simulation( + n=n_img, + L=resolution, + vols=vol, + offsets=offsets, + amplitudes=1, + seed=SEED, + ) + src = src.cache() # Precompute image stack + + return src + +@pytest.fixture(scope="module") +def orient_est(src, proximal_refine): + orient_est = CommonlineNUG( + src, + perform_pr=proximal_refine, + ) + + return orient_est + + +######### +# Tests # +######### + + +def test_dtypes(orient_est): + """ + Check dtypes for each major step of the algorithm. + """ + pass + From 636a3e4e65fe5d015b42328b99c1302647eb86c6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 15 May 2026 14:30:08 -0400 Subject: [PATCH 22/38] clean up estimate_rotations --- src/aspire/abinitio/commonline_nug.py | 160 +++++++++++--------------- tests/test_nug.py | 12 +- 2 files changed, 71 insertions(+), 101 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index c3bae03b70..f64a4f2689 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -42,6 +42,7 @@ def __init__( S2_grid=441, Nstep_yI=10, perform_pr=False, + verbose=True, **kwargs, ): """ @@ -74,6 +75,7 @@ def __init__( self.S2_grid = S2_grid self.Nstep_yI = Nstep_yI self.perform_pr = perform_pr + self.verbose = verbose # Handle symmetry if symmetry is None: @@ -98,56 +100,24 @@ def _build_full_pft(self): self.pf_full = PolarFT.half_to_full(pf) def estimate_rotations(self): - imgs = self.src.images[:] - C = self.compute_coeff(imgs, self.loss, self.Lmax, T=self.T) - X_est = self.admm_sym_J( - C, - self.Lmax, - self.n_img, - self.max_iter, - self.rho, - self.ratio, - self.factor, - self.mult, - self.Nstep_yI, - ) - - if self.perform_pr: - weight = 1 / (1 + np.arange(self.Lmax)) - Penalty = [1, 1, 1, 1] - r = [3, 2, 1, 0] - X_est = self.proximal_refine( - X_est, - C, - self.n_img, - weight, - Penalty, - r, - self.max_iter, - self.rho, - self.ratio, - self.factor, - self.mult, - self.Nstep_yI, - ) - - if isinstance(self.sym_grp, CnSymmetryGroup): - R_est, Euler_est = self.euler_est(X_est[0], X_est[self.n_sym - 1]) - elif isinstance(self.sym_grp, DnSymmetryGroup): - R_est, Euler_est = self.euler_est_Dm(X_est) - - self.rotations = R_est - - return R_est + self.compute_coeff() + self.perform_admm() + self.euler_est() + return self.rotations ####################### # Compute Coeffs Step # ####################### - def compute_coeff(self, Img, loss, Lmax, T): + def compute_coeff(self): # compute the coefficient matrix - N, L, _ = Img.shape - n_theta = 360 + Img = self.src.images[:] + N = self.n_img + L = self.src.L + n_theta = self.n_theta + loss = self.loss + Lmax = self.Lmax + T = self.T angular_sampling = np.arange(0, 360, 1) # Using ASPIRE Image.project(). Leaving original method in comments for now. @@ -264,7 +234,7 @@ def fijhat_k(k, F): ) C[k - 1] = np.round(C[k - 1], 10) - return C + self.C = C @staticmethod def fast_radon_transform(array, angles, use_ramp=False): @@ -331,19 +301,31 @@ def complex2real(ell): # ADMM Step # ############# - def admm_sym_J( - self, - C, - Lmax, - N, - max_iter, - rho, - ratio, - factor, - mult=1, - Nstep_yI=20, - verbose=True, - ): + def perform_admm(self): + X_est = self.admm_sym_J(self.C, self.verbose) + + if self.perform_pr: + weight = 1 / (1 + np.arange(self.Lmax)) + Penalty = [1, 1, 1, 1] + r = [3, 2, 1, 0] + X_est = self.proximal_refine( + X_est, + weight, + Penalty, + r, + ) + self.X_est = X_est + + def admm_sym_J(self, C, verbose): + Lmax = self.Lmax + N = self.n_img + max_iter = self.max_iter + rho = self.rho + ratio = self.ratio + factor = self.factor + mult = self.mult + Nstep_yI = self.Nstep_yI + # admm for symmetric case ( C0, @@ -370,7 +352,7 @@ def admm_sym_J( S0, S1, Sq, - ) = self.ADMM_preprocessing(C, Lmax, N) + ) = self.ADMM_preprocessing(C) n_pairs = N * (N - 1) // 2 Ngrid = self.Ngrid @@ -598,7 +580,7 @@ def update_rho(X0, X1, Xd0, Xd1, Xq, bE, bEq, bI, res_X, rho, factor, normC): rho = rho / factor return rho, p_resnorm, d_resnorm - def print_updates(verbose=True): + def print_updates(verbose): # X_admm=transform_coeff_back(X0,X1,Lmax,N); obj_p=0 # for k in range(Lmax): obj_p+=xp.trace(C[k]@X_admm[k]) if verbose: @@ -741,9 +723,11 @@ def print_updates(verbose=True): X_admm[k] = xp.asnumpy(X_admm[k]) return X_admm - def ADMM_preprocessing(self, C, Lmax, N): + def ADMM_preprocessing(self, C): # compute necessary quantities for ADMM # compute some useful index sets + Lmax = self.Lmax + N = self.n_img count = 0 idx_diag = [] idx_offdiag = [] @@ -980,22 +964,17 @@ def discretize_SO3(self): # Proximal Refinement Step # ############################ - def proximal_refine( - self, - X_admm, - C, - N, - weight, - Penalty, - r, - max_iter, - rho, - ratio, - factor, - mult, - Nstep_yI, - verbose=True, - ): + def proximal_refine(self, X_admm, weight, Penalty, r): + Lmax = self.Lmax + N = self.n_img + max_iter = self.max_iter + rho = self.rho + ratio = self.ratio + factor = self.factor + mult = self.mult + Nstep_yI = self.Nstep_yI + C = self.C + def Ak(J, Euler): # compute Ak matrix order = Euler.shape[0] @@ -1047,20 +1026,9 @@ def low_rank_proj(X, r_step): - Penalty[step] * weight[k] * (X_proj[k] + X_proj[k].T) / 2 ) - X_next = self.admm_sym_J( - CC, - self.Lmax, - N, - max_iter, - rho, - ratio, - factor, - mult, - Nstep_yI, - verbose=False, - ) + X_next = self.admm_sym_J(CC, verbose=False) - if verbose: + if self.verbose: logger.info( "Proximal refine step %d/%d: relative update %.3e", step + 1, @@ -1076,7 +1044,17 @@ def low_rank_proj(X, r_step): # Euler Estimation Step # ######################### - def euler_est(self, X1, XS): + def euler_est(self): + X_est = self.X_est + if isinstance(self.sym_grp, CnSymmetryGroup): + R_est, Euler_est = self.euler_est_Cm(X_est[0], X_est[self.n_sym - 1]) + elif isinstance(self.sym_grp, DnSymmetryGroup): + R_est, Euler_est = self.euler_est_Dm(X_est) + + self.Euler_est = Euler_est + self.rotations = R_est + + def euler_est_Cm(self, X1, XS): S = self.n_sym N = self.n_img sym_euler = np.zeros((S, 3)) diff --git a/tests/test_nug.py b/tests/test_nug.py index 4affb492e7..a221f36bd7 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -3,15 +3,6 @@ from aspire.abinitio import CommonlineNUG from aspire.source import Simulation -from aspire.utils import ( - J_conjugate, - Random, - Rotation, - all_pairs, - mean_aligned_angular_distance, - utest_tolerance, -) -from aspire.volume import CnSymmetricVolume, CnSymmetryGroup DTYPE = [np.float32] RESOLUTION = [48, 49] @@ -56,6 +47,7 @@ def proximal_refine(request): # Fixtures # ############ + @pytest.fixture(scope="module") def source(n_img, resolution, dtype, offsets, order): vol = CnSymmetricVolume( @@ -74,6 +66,7 @@ def source(n_img, resolution, dtype, offsets, order): return src + @pytest.fixture(scope="module") def orient_est(src, proximal_refine): orient_est = CommonlineNUG( @@ -94,4 +87,3 @@ def test_dtypes(orient_est): Check dtypes for each major step of the algorithm. """ pass - From 98d1c874f839bb72d3f754238d994f21a40ea2ac Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 19 May 2026 16:23:45 -0400 Subject: [PATCH 23/38] Explicit dtypes. Add dtype test. Add estimate_rotations test. --- src/aspire/abinitio/commonline_nug.py | 271 ++++++++++++++++---------- tests/test_nug.py | 66 ++++++- 2 files changed, 222 insertions(+), 115 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index f64a4f2689..944394c5a4 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -2,14 +2,13 @@ import time import numpy as np -from scipy.io import loadmat from scipy.special import factorial from aspire.abinitio import Orient3D from aspire.nufft import nufft from aspire.numeric import fft, xp from aspire.operators import PolarFT, wemd_embed -from aspire.utils import Rotation, cart2sph +from aspire.utils import Rotation, cart2sph, complex_type from aspire.volume import CnSymmetryGroup, DnSymmetryGroup, SymmetryGroup from .commonline_utils import saff_kuijlaars @@ -113,12 +112,11 @@ def compute_coeff(self): # compute the coefficient matrix Img = self.src.images[:] N = self.n_img - L = self.src.L n_theta = self.n_theta loss = self.loss Lmax = self.Lmax T = self.T - angular_sampling = np.arange(0, 360, 1) + angular_sampling = np.arange(0, 360, 1, dtype=np.float64) # Using ASPIRE Image.project(). Leaving original method in comments for now. line_proj = Img.project(angular_sampling).asnumpy().T @@ -130,13 +128,6 @@ def compute_coeff(self): # line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( # Img[n], angular_sampling # ) - - dim_wave = len(wemd_embed(line_proj[:, 0, 0])) - WE = np.zeros((dim_wave, n_theta, N)) - for i in range(N): - for theta in range(n_theta): - WE[:, theta, i] = wemd_embed(line_proj[:, theta, i]) - def fij(alpha, gamma, i, j, loss): if loss == "l1": # Ii_hat = Img_pft[:, :, i] @@ -160,6 +151,12 @@ def fij(alpha, gamma, i, j, loss): return np.linalg.norm(Si - Sj, 1) if loss == "wemd": + dim_wave = len(wemd_embed(line_proj[:, 0, 0])) + WE = np.zeros((dim_wave, n_theta, N), dtype=np.float64) + for i in range(N): + for theta in range(n_theta): + WE[:, theta, i] = wemd_embed(line_proj[:, theta, i]) + idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta @@ -167,11 +164,11 @@ def fij(alpha, gamma, i, j, loss): Sj = WE[:, int(idxj), j] return np.linalg.norm(Si - Sj, 1) - alpha_grid = np.arange(2 * T) * np.pi / T - beta_grid = (2 * np.arange(2 * T) + 1) * np.pi / 4 / T - gamma_grid = np.arange(2 * T) * np.pi / T + alpha_grid = np.arange(2 * T, dtype=np.float64) * np.pi / T + beta_grid = (2 * np.arange(2 * T, dtype=np.float64) + 1) * np.pi / 4 / T + gamma_grid = np.arange(2 * T, dtype=np.float64) * np.pi / T - bT = np.zeros(2 * T) + bT = np.zeros(2 * T, dtype=np.float64) for n in range(2 * T): ss = 0 for m in range(T): @@ -186,11 +183,11 @@ def fij(alpha, gamma, i, j, loss): def fijhat_k(k, F): dk = 2 * k + 1 - exp_alpha_grid = np.zeros((2 * T, dk), dtype=complex) + exp_alpha_grid = np.zeros((2 * T, dk), dtype=complex_type(np.float64)) for m in range(-k, k + 1): exp_alpha_grid[:, m + k] = np.exp(1j * m * alpha_grid) - exp_gamma_grid = np.zeros((2 * T, dk), dtype=complex) + exp_gamma_grid = np.zeros((2 * T, dk), dtype=complex_type(np.float64)) for m in range(-k, k + 1): exp_gamma_grid[:, m + k] = np.exp(1j * m * gamma_grid) @@ -201,10 +198,10 @@ def fijhat_k(k, F): C = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - C.append(np.zeros((N * dk, N * dk), dtype=complex)) + C.append(np.zeros((N * dk, N * dk), dtype=complex_type(np.float64))) for i in range(N): for j in range(i + 1, N): - Fij = np.zeros((2 * T, 2 * T)) + Fij = np.zeros((2 * T, 2 * T), dtype=np.float64) for j1 in range(2 * T): for j2 in range(2 * T): Fij[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, j, loss) @@ -217,7 +214,7 @@ def fijhat_k(k, F): C[k - 1] = C[k - 1] + C[k - 1].conj().T for i in range(N): - Fii = np.zeros((2 * T, 2 * T)) + Fii = np.zeros((2 * T, 2 * T), dtype=np.float64) for j1 in range(2 * T): for j2 in range(2 * T): Fii[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, i, loss) @@ -230,9 +227,10 @@ def fijhat_k(k, F): for k in range(1, Lmax + 1): [T, Tinv] = self.complex2real(k) C[k - 1] = np.real( - np.kron(np.eye(N), Tinv) @ C[k - 1] @ np.kron(np.eye(N), T) + np.kron(np.eye(N, dtype=np.float64), Tinv) + @ C[k - 1] + @ np.kron(np.eye(N, dtype=np.float64), T) ) - C[k - 1] = np.round(C[k - 1], 10) self.C = C @@ -270,11 +268,10 @@ def fast_radon_transform(array, angles, use_ramp=False): return projections, lines_f - @staticmethod - def complex2real(ell): + def complex2real(self, ell): # compute transformation matrices that convert complex representations to real ones diml = 2 * ell + 1 - Tinv = np.zeros((diml, diml), dtype=complex) + Tinv = np.zeros((diml, diml), dtype=complex_type(np.float64)) for i in range(diml): if i < ell: Tinv[i, i] = 1j / np.sqrt(2) @@ -285,7 +282,7 @@ def complex2real(ell): Tinv[i, i] = (-1) ** (i - ell) / np.sqrt(2) Tinv[i, diml - 1 - i] = 1 / np.sqrt(2) - T = np.zeros((diml, diml), dtype=complex) + T = np.zeros((diml, diml), dtype=complex_type(np.float64)) for i in range(diml): if i < ell: T[i, i] = -1j / np.sqrt(2) @@ -367,9 +364,9 @@ def admm_sym_J(self, C, verbose): for k in range(1, Lmax + 1): s0 = k**2 s1 = (k + 1) ** 2 - AEk = xp.zeros((1 + s0 + s1, 2 * (s0 + s1))) - AEk[0, :s0] = xp.eye(k).T.reshape(-1) - AEk[0, s0 : s0 + s1] = xp.eye(k + 1).T.reshape(-1) + AEk = xp.zeros((1 + s0 + s1, 2 * (s0 + s1)), dtype=np.float64) + AEk[0, :s0] = xp.eye(k, dtype=np.float64).T.reshape(-1) + AEk[0, s0 : s0 + s1] = xp.eye(k + 1, dtype=np.float64).T.reshape(-1) for count in range(1, 1 + s0): AEk[count, count - 1] = 1 AEk[count, count - 1 + s0 + s1] = 1 @@ -378,7 +375,7 @@ def admm_sym_J(self, C, verbose): AEk[count, count - 1 + s1 + s0] = 1 AE.append(AEk) AEAETinv.append(np.linalg.pinv(AEk @ AEk.T)) - bE = xp.zeros((Lmax + D0 + D1)) + bE = xp.zeros((Lmax + D0 + D1), dtype=np.float64) for k in range(Lmax): bE[k + d0[k] + d1[k] :] = rank_Ak[k] bE[k + 1 + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k]] = xp.eye( @@ -391,7 +388,7 @@ def admm_sym_J(self, C, verbose): P = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - Pk = xp.eye(dk) + Pk = xp.eye(dk, dtype=np.float64) for m in range(k): for el in range(k - m): Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ @@ -400,7 +397,7 @@ def admm_sym_J(self, C, verbose): P.append(Pk) def fun_AE(X0, X1, Xd0, Xd1, Xq): - z = xp.zeros((Lmax + D0 + D1, N)) + z = xp.zeros((Lmax + D0 + D1, N), dtype=np.float64) for k in range(Lmax): z[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AE[ k @@ -418,10 +415,10 @@ def fun_AE(X0, X1, Xd0, Xd1, Xq): ) def fun_AET(yE, yEq): - Z0 = xp.zeros((D0, N * (N + 1) // 2)) - Z1 = xp.zeros((D1, N * (N + 1) // 2)) - Zd0 = xp.zeros((D0, N)) - Zd1 = xp.zeros((D1, N)) + Z0 = xp.zeros((D0, N * (N + 1) // 2), dtype=np.float64) + Z1 = xp.zeros((D1, N * (N + 1) // 2), dtype=np.float64) + Zd0 = xp.zeros((D0, N), dtype=np.float64) + Zd1 = xp.zeros((D1, N), dtype=np.float64) for k in range(Lmax): s0 = (k + 1) ** 2 s1 = (k + 2) ** 2 @@ -436,7 +433,7 @@ def fun_AET(yE, yEq): return Z0, Z1, Zd0, Zd1, Zq[:16] def fun_AI(X0, X1): - z = xp.zeros((Ngrid, N * (N + 1) // 2)) + z = xp.zeros((Ngrid, N * (N + 1) // 2), dtype=np.float64) tmp = xp.concatenate((X0, X1), axis=0) z[:, idx_diag] = AI_mat_diag @ tmp[:, idx_diag] z[:, idx_offdiag] = AI_mat_offdiag @ tmp[:, idx_offdiag] @@ -446,7 +443,7 @@ def fun_AI(X0, X1): def fun_AIT(yI): # Z=AI_mat.T@yI # return Z[:D0,:], Z[D0:,:] - Z = xp.zeros((D0 + D1, N * (N + 1) // 2)) + Z = xp.zeros((D0 + D1, N * (N + 1) // 2), dtype=np.float64) Z[:, idx_diag] = AI_mat_diag.T @ yI[:, idx_diag] Z[:, idx_offdiag] = AI_mat_offdiag.T @ yI[:, idx_offdiag] return Z[:D0, :], Z[D0:, :] @@ -516,7 +513,7 @@ def update_yE(C0, C1, X0, X1, Xd0, Xd1, Xq, S0, S1, Sd0, Sd1, Sq, yI, rho): -Xd1 / rho - Sd1, -Xq / rho - Sq, ) - yE = xp.zeros((Lmax + D0 + D1, N)) + yE = xp.zeros((Lmax + D0 + D1, N), dtype=np.float64) for k in range(Lmax): yE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] = AEAETinv[k] @ ( bE[k + d0[k] + d1[k] : k + 1 + d0[k + 1] + d1[k + 1]] / rho @@ -680,13 +677,13 @@ def print_updates(verbose): + ", |X|=%1.2f" % normX ) - Xd0 = xp.zeros((D0, N)) - Xd1 = xp.zeros((D1, N)) - Sd0 = xp.zeros(Xd0.shape) - Sd1 = xp.zeros(Xd1.shape) - yI = xp.zeros((Ngrid, N * (N + 1) // 2)) - yE = xp.zeros(bE.shape) - yEq = xp.zeros(bEq.shape) + Xd0 = xp.zeros((D0, N), dtype=np.float64) + Xd1 = xp.zeros((D1, N), dtype=np.float64) + Sd0 = xp.zeros(Xd0.shape, dtype=np.float64) + Sd1 = xp.zeros(Xd1.shape, dtype=np.float64) + yI = xp.zeros((Ngrid, N * (N + 1) // 2), dtype=np.float64) + yE = xp.zeros(bE.shape, dtype=np.float64) + yEq = xp.zeros(bEq.shape, dtype=np.float64) IDX = np.arange(3) Time = np.zeros(4) @@ -774,11 +771,13 @@ def ADMM_preprocessing(self, C): D1 = d1[-1] # AE and bE for quaternion constraints - AEq = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) - AEqAEqtinv = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) + # AEq = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) + # AEqAEqtinv = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) + AEq = xp.asarray(self.construct_AEq()) + AEqAEqtinv = xp.linalg.pinv(AEq @ AEq.T) - bEq = xp.zeros(17) - bEq[:16] = xp.eye(4).reshape(-1) / 4 + bEq = xp.zeros(17, dtype=np.float64) + bEq[:16] = xp.eye(4, dtype=np.float64).reshape(-1) / 4 bEq[-1] = 1 bEq = xp.repeat(bEq[:, xp.newaxis], N * (N - 1) // 2, axis=1) @@ -794,10 +793,10 @@ def ADMM_preprocessing(self, C): # AI_mat[p,:d0[-1]]=w0; AI_mat[p,d0[-1]:]=w1 # AI_mat=xp.asarray(AI_mat) / 10; # bI=-(Lmax+2)*(Lmax+1)/2 / 10 - AI_mat_offdiag = np.zeros((Ngrid, D0 + D1)) + AI_mat_offdiag = np.zeros((Ngrid, D0 + D1), dtype=np.float64) for p in range(Ngrid): - w0 = np.zeros(D0) - w1 = np.zeros(D1) + w0 = np.zeros(D0, dtype=np.float64) + w1 = np.zeros(D1, dtype=np.float64) for k in range(1, Lmax + 1): w0[d0[k - 1] : d0[k]] = ( (Lmax - k + 2) @@ -826,10 +825,10 @@ def ADMM_preprocessing(self, C): # AI_mat_offdiag_new[:, d0[k - 1]:d0[k]] = block0 # AI_mat_offdiag_new[:, d0[-1] + d1[k - 1]:d0[-1] + d1[k]] = block1 - AI_mat_diag = np.zeros((Ngrid, D0 + D1)) + AI_mat_diag = np.zeros((Ngrid, D0 + D1), dtype=np.float64) for p in range(Ngrid): - w0 = np.zeros(D0) - w1 = np.zeros(D1) + w0 = np.zeros(D0, dtype=np.float64) + w1 = np.zeros(D1, dtype=np.float64) for k in range(1, Lmax + 1): w0[d0[k - 1] : d0[k]] = ( (Lmax - k + 2) @@ -857,15 +856,15 @@ def ADMM_preprocessing(self, C): II = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - II.append(xp.eye(N * dk)) + II.append(xp.eye(N * dk, dtype=np.float64)) I0, I1 = self.transform_coeff(II, Lmax, N, IDX_upper) - X0 = xp.zeros((D0, N * (N + 1) // 2)) - X1 = xp.zeros((D1, N * (N + 1) // 2)) - Xq = xp.zeros((16, N * (N - 1) // 2)) + X0 = xp.zeros((D0, N * (N + 1) // 2), dtype=np.float64) + X1 = xp.zeros((D1, N * (N + 1) // 2), dtype=np.float64) + Xq = xp.zeros((16, N * (N - 1) // 2), dtype=np.float64) # X0,X1=transform_coeff(II,Lmax,N); Xq=xp.zeros((16,N*(N-1))); Xq[0,:]=1; S0 = xp.copy(I0) S1 = xp.copy(I1) - Sq = xp.zeros(Xq.shape) + Sq = xp.zeros(Xq.shape, dtype=np.float64) # S0=xp.zeros(X0.shape); S1=xp.zeros(X1.shape); Sq=xp.zeros(Xq.shape) self.Ngrid = Ngrid @@ -911,7 +910,7 @@ def compute_fejer_weights(self): def permutek_block(Ak, k): dk = 2 * k + 1 - Pk = np.eye(dk) + Pk = np.eye(dk, dtype=np.float64) for m in range(k): for ell in range(k - m): Pk[(m + 2 * ell, m + 2 * ell + 1), :] = Pk[ @@ -923,8 +922,8 @@ def permutek_block(Ak, k): W0 = [] W1 = [] for k in range(start, self.Lmax + 1): - W0k = np.zeros((Ngrid, k, k)) - W1k = np.zeros((Ngrid, k + 1, k + 1)) + W0k = np.zeros((Ngrid, k, k), dtype=np.float64) + W1k = np.zeros((Ngrid, k + 1, k + 1), dtype=np.float64) TkT = TT[k - start].T TinvkT = TTI[k - start].T @@ -951,7 +950,7 @@ def discretize_SO3(self): gamma = gamma + np.pi # SO(3) in Euler ZYZ - SO3 = np.zeros((S2_size * S1_size, 3)) + SO3 = np.zeros((S2_size * S1_size, 3), dtype=np.float64) count = 0 for i in range(S1_size): for j in range(S2_size): @@ -965,14 +964,7 @@ def discretize_SO3(self): ############################ def proximal_refine(self, X_admm, weight, Penalty, r): - Lmax = self.Lmax N = self.n_img - max_iter = self.max_iter - rho = self.rho - ratio = self.ratio - factor = self.factor - mult = self.mult - Nstep_yI = self.Nstep_yI C = self.C def Ak(J, Euler): @@ -989,7 +981,7 @@ def rel_change(A, B, eps=1e-12): den += np.linalg.norm(B[k]) ** 2 return np.sqrt(num) / max(np.sqrt(den), eps) - rank_Ak = np.zeros(self.Lmax) + rank_Ak = np.zeros(self.Lmax, dtype=np.float64) C_base = [None] * self.Lmax for k in range(self.Lmax): C_base[k] = xp.asnumpy(C[k]).copy() @@ -1052,18 +1044,26 @@ def euler_est(self): R_est, Euler_est = self.euler_est_Dm(X_est) self.Euler_est = Euler_est - self.rotations = R_est + self.rotations = R_est.astype(self.dtype) def euler_est_Cm(self, X1, XS): S = self.n_sym N = self.n_img - sym_euler = np.zeros((S, 3)) + sym_euler = np.zeros((S, 3), dtype=np.float64) for s in range(S): sym_euler[s] = [2 * np.pi * s / S, 0, 0] [T, Tinv] = self.complex2real(1) - X1 = np.kron(np.eye(N), T) @ X1 @ np.kron(np.eye(N), Tinv) + X1 = ( + np.kron(np.eye(N, dtype=np.float64), T) + @ X1 + @ np.kron(np.eye(N, dtype=np.float64), Tinv) + ) [T, Tinv] = self.complex2real(S) - XS = np.kron(np.eye(N), T) @ XS @ np.kron(np.eye(N), Tinv) + XS = ( + np.kron(np.eye(N, dtype=np.float64), T) + @ XS + @ np.kron(np.eye(N, dtype=np.float64), Tinv) + ) def find_phase(A, B): # find a number c that minimizes ||cA-B||_F @@ -1079,8 +1079,8 @@ def find_phase(A, B): return c / abs(c) def find_beta(X1): - B1 = np.zeros((N, N)) - B2 = np.zeros((N, N)) + B1 = np.zeros((N, N), dtype=np.float64) + B2 = np.zeros((N, N), dtype=np.float64) for i in range(N): for j in range(N): Xij = X1[3 * i : 3 * (i + 1), 3 * j : 3 * (j + 1)] @@ -1096,8 +1096,8 @@ def find_beta(X1): return beta def find_alpha(X1): - ZZbar = np.zeros((N, N), dtype=complex) - ZZ = np.zeros((N, N), dtype=complex) + ZZbar = np.zeros((N, N), dtype=complex_type(np.float64)) + ZZ = np.zeros((N, N), dtype=complex_type(np.float64)) for i in range(N): for j in range(N): z = X1[3 * i : 3 * (i + 1), 3 * j : 3 * (j + 1)][0, 0] @@ -1112,12 +1112,12 @@ def find_alpha(X1): c = find_phase(Z[:, None] @ Z[:, None].T, ZZ) Z = np.sqrt(c) * Z - return np.angle(Z) + return np.angle(Z).astype(np.float64) dk = 2 * S + 1 def find_gamma(Xm, beta, alpha): - C = np.zeros((N, N), dtype=complex) + C = np.zeros((N, N), dtype=complex_type(np.float64)) Jk = np.ones(dk) Jk[S + 1 :: 2] = -1 Jk[S - 1 :: -2] = -1 @@ -1150,13 +1150,13 @@ def find_gamma(Xm, beta, alpha): C[i, j] = np.vdot(C1 + C2, np.real(C3)) / np.vdot( C1 + C2, C1 + C2 ) + 1j * np.vdot(C1 - C2, np.imag(C3)) / np.vdot(C1 - C2, C1 - C2) - C += C.T.conj() + np.eye(N) + C += C.T.conj() + np.eye(N, dtype=np.float64) evals, evecs = np.linalg.eigh(C) idx = np.argmax(evals) c = evecs[:, idx] * np.sqrt(evals[idx]) return np.angle(c) / S - Euler_est = np.zeros((N, 3)) + Euler_est = np.zeros((N, 3), dtype=np.float64) Euler_est[:, 0] = find_alpha(X1) Euler_est[:, 1] = find_beta(X1) Euler_est[:, 2] = find_gamma(XS, Euler_est[:, 1], Euler_est[:, 0]) @@ -1186,9 +1186,13 @@ def find_phase(A, B): return c / abs(c) T, Tinv = self.complex2real(2) - X2 = np.kron(np.eye(N), T) @ X2 @ np.kron(np.eye(N), Tinv) + X2 = ( + np.kron(np.eye(N, dtype=np.float64), T) + @ X2 + @ np.kron(np.eye(N, dtype=np.float64), Tinv) + ) - B1 = np.zeros((N, N)) + B1 = np.zeros((N, N), dtype=np.float64) for i in range(N): for j in range(N): Xij = X2[5 * i : 5 * (i + 1), 5 * j : 5 * (j + 1)] @@ -1198,8 +1202,8 @@ def find_phase(A, B): b1 = v1[:, idx] * np.sqrt(e1[idx]) * np.sign(v1[0, idx]) beta_est = np.arccos(np.clip(np.sqrt(b1), -1, 1)) % np.pi - Aminus = np.zeros((N, N), dtype=complex) - Aplus = np.zeros((N, N), dtype=complex) + Aminus = np.zeros((N, N), dtype=complex_type(np.float64)) + Aplus = np.zeros((N, N), dtype=complex_type(np.float64)) for i in range(N): for j in range(N): Xij = X2[5 * i : 5 * (i + 1), 5 * j : 5 * (j + 1)] @@ -1253,8 +1257,12 @@ def LS_D(W1, W2, W3, W4, Br, Bi): return a + 1j * b [T, Tinv] = self.complex2real(S) - Xm = np.kron(np.eye(N), T) @ Xm @ np.kron(np.eye(N), Tinv) - C = np.zeros((N, N), dtype=complex) + Xm = ( + np.kron(np.eye(N, dtype=np.float64), T) + @ Xm + @ np.kron(np.eye(N, dtype=np.float64), Tinv) + ) + C = np.zeros((N, N), dtype=complex_type(np.float64)) Jk = np.ones(dk) Jk[S + 1 :: 2] = -1 Jk[S - 1 :: -2] = -1 @@ -1290,7 +1298,7 @@ def LS_D(W1, W2, W3, W4, Br, Bi): Br = np.real(4 * DXijmD) Bi = np.imag(4 * DXijmD) C[i, j] = LS_D(W1, W2, W3, W4, Br, Bi) - C += C.T.conj() + np.eye(N) + C += C.T.conj() + np.eye(N, dtype=np.float64) evals, evecs = np.linalg.eigh(C) idx = np.argmax(evals) c = evecs[:, idx] * np.sqrt(evals[idx]) @@ -1298,7 +1306,7 @@ def LS_D(W1, W2, W3, W4, Br, Bi): alpha_est, beta_est = find_alpha_beta(X2) gamma_est = find_gamma(XS, alpha_est, beta_est) - Euler_est = np.zeros((N, 3)) + Euler_est = np.zeros((N, 3), dtype=np.float64) Euler_est[:, 0] = alpha_est Euler_est[:, 1] = beta_est Euler_est[:, 2] = gamma_est @@ -1315,8 +1323,8 @@ def transform_coeff(self, A, Lmax, N, IDX_upper): for k in range(1, Lmax + 1): d0.append(d0[-1] + k**2) d1.append(d1[-1] + (k + 1) ** 2) - A0 = xp.zeros((d0[-1], N * (N + 1) // 2)) - A1 = xp.zeros((d1[-1], N * (N + 1) // 2)) + A0 = xp.zeros((d0[-1], N * (N + 1) // 2), dtype=np.float64) + A1 = xp.zeros((d1[-1], N * (N + 1) // 2), dtype=np.float64) for k in range(1, Lmax + 1): a0, a1 = self.permutek(A[k - 1], k, N) A0[d0[k - 1] : d0[k], :] = self.vec_block(a0, N, k, IDX_upper) @@ -1332,7 +1340,7 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia A = [] for k in range(1, Lmax + 1): dk = 2 * k + 1 - Ak = xp.zeros((N * dk, N * dk)) + Ak = xp.zeros((N * dk, N * dk), dtype=np.float64) Ak[: N * k, : N * k] = self.mat_block( A0[d0[k - 1] : d0[k], :], N, k, IDX_upper, IDX_lower, idx_offdiag ) @@ -1347,13 +1355,17 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia def permutek(Ak, k, N): AkP = xp.copy(Ak) dk = 2 * k + 1 - Pk = xp.eye(dk) + Pk = xp.eye(dk, dtype=AkP.dtype) for m in range(k): for n in range(k - m): Pk[(m + 2 * n, m + 2 * n + 1), :] = Pk[(m + 2 * n + 1, m + 2 * n), :] - AkP = xp.kron(xp.eye(N), Pk) @ Ak @ xp.kron(xp.eye(N), Pk.T) + AkP = ( + xp.kron(xp.eye(N, dtype=AkP.dtype), Pk) + @ Ak + @ xp.kron(xp.eye(N, dtype=AkP.dtype), Pk.T) + ) - Pk = xp.eye(N * dk) + Pk = xp.eye(N * dk, dtype=AkP.dtype) idx = xp.concatenate((xp.arange(dk - k, dk), xp.arange(k + 1))) for m in range(N - 1): for n in range(N - 1 - m): @@ -1366,7 +1378,7 @@ def permutek(Ak, k, N): @staticmethod def permutek_back(Ak, k, N): dk = 2 * k + 1 - Pk = xp.eye(N * dk) + Pk = xp.eye(N * dk, dtype=Ak.dtype) idx = xp.concatenate((xp.arange(dk - k, dk), xp.arange(k + 1))) for m in range(N - 1): for n in range(N - 1 - m): @@ -1375,11 +1387,15 @@ def permutek_back(Ak, k, N): ][idx, :] AkB = Pk.T @ Ak @ Pk dk = 2 * k + 1 - Pk = xp.eye(dk) + Pk = xp.eye(dk, dtype=Ak.dtype) for m in range(k): for n in range(k - m): Pk[(m + 2 * n, m + 2 * n + 1), :] = Pk[(m + 2 * n + 1, m + 2 * n), :] - AkB = xp.kron(xp.eye(N), Pk.T) @ AkB @ xp.kron(xp.eye(N), Pk) + AkB = ( + xp.kron(xp.eye(N, dtype=Ak.dtype), Pk.T) + @ AkB + @ xp.kron(xp.eye(N, dtype=Ak.dtype), Pk) + ) return AkB @staticmethod @@ -1404,7 +1420,7 @@ def largest_eigenvalue(AI, Ngrid, N): return Lambda def compute_rank(self, Lmax): - rk = xp.zeros(Lmax) + rk = xp.zeros(Lmax, dtype=np.float64) A = [] for k in range(1, Lmax + 1): Ak = np.sum(self.WD(k, self.sym_euler), axis=0) @@ -1430,7 +1446,7 @@ def WD(self, J, euler): @staticmethod def Wd(J, beta): # compute Wigner small d matrix - d = np.zeros((len(beta), 2 * J + 1, 2 * J + 1)) + d = np.zeros((len(beta), 2 * J + 1, 2 * J + 1), dtype=beta.dtype) for m in range(-J, J + 1): for n in range(-J, J + 1): smin = max(0, m - n) @@ -1457,7 +1473,7 @@ def Wd(J, beta): @staticmethod def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): tmp = vecA.T.reshape(N * (N + 1) // 2, sz, sz).transpose(0, 2, 1) - AA = xp.zeros((N**2, sz, sz)) + AA = xp.zeros((N**2, sz, sz), dtype=vecA.dtype) AA[IDX_upper] = tmp AA[IDX_lower] = tmp[idx_offdiag].transpose(0, 2, 1) return (AA.reshape(N, N, sz, sz).transpose(0, 2, 1, 3)).reshape(N * sz, N * sz) @@ -1478,7 +1494,7 @@ def transform_block(A, k, Pk=None): if Pk is None: dk = 2 * k + 1 - Pk = xp.eye(dk) + Pk = xp.eye(dk, dtype=A.dtype) for m in range(k): for el in range(k - m): Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ @@ -1502,11 +1518,11 @@ def transform_back_block(A0, A1, k, Pk=None): A0 = A0[None, :] A1 = A1[None, :] - A = xp.zeros((A0.shape[0], dk, dk)) + A = xp.zeros((A0.shape[0], dk, dk), dtype=A0.dtype) A[:, :k, :k] = A0.reshape(-1, k, k).swapaxes(-1, -2) A[:, k:, k:] = A1.reshape(-1, k + 1, k + 1).swapaxes(-1, -2) if Pk is None: - Pk = xp.eye(dk) + Pk = xp.eye(dk, dtype=A0.dtype) for m in range(k): for el in range(k - m): Pk[(m + 2 * el, m + 2 * el + 1), :] = Pk[ @@ -1514,3 +1530,42 @@ def transform_back_block(A0, A1, k, Pk=None): ] out = Pk.T @ A @ Pk return out[0] if single else out + + def construct_AEq(self): + """ + Construct the linear equality matrix for quaternion constraints. + """ + AEq = np.zeros((17, 21), np.float64) + + # First 16 rows: identity constraints on first 16 variables + AEq[:16, :16] = np.eye(16, dtype=np.float64) + + # Extra columns 16:21 + extra = 0.25 * np.array( + [ + [-1, 1, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 1, -1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, -1], + [0, 0, 0, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 1, -1, 0], + [0, 0, 0, 0, 0], + [-1, -1, 0, 0, -1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 1], + ], + dtype=np.float64, + ) + + AEq[:16, 16:] = extra + + # Last row: redundant trace/sum constraint + AEq[16, [0, 5, 10, 15]] = 1 + + return AEq diff --git a/tests/test_nug.py b/tests/test_nug.py index a221f36bd7..1eb65a71d6 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -3,12 +3,13 @@ from aspire.abinitio import CommonlineNUG from aspire.source import Simulation +from aspire.volume import CnSymmetricVolume, SymmetryGroup -DTYPE = [np.float32] -RESOLUTION = [48, 49] +DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +RESOLUTION = [48, pytest.param(48, marks=pytest.mark.expensive)] N_IMG = [5] OFFSETS = [0] -ORDER = [3, 4] +ORDER = [3, pytest.param(4, marks=pytest.mark.expensive)] PR = [False] SEED = 1980 @@ -68,12 +69,12 @@ def source(n_img, resolution, dtype, offsets, order): @pytest.fixture(scope="module") -def orient_est(src, proximal_refine): +def orient_est(source, proximal_refine): orient_est = CommonlineNUG( - src, + source, perform_pr=proximal_refine, ) - + orient_est.estimate_rotations() return orient_est @@ -86,4 +87,55 @@ def test_dtypes(orient_est): """ Check dtypes for each major step of the algorithm. """ - pass + assert orient_est.dtype == orient_est.src.dtype + + # Intermediate steps use doubles + for Ci in orient_est.C: + assert Ci.dtype == np.float64 + + for Xi in orient_est.X_est: + assert Xi.dtype == np.float64 + + assert orient_est.rotations.dtype == orient_est.dtype + + +def test_estimate_rotations_pairwise(orient_est): + """ """ + MSE = compare_rots_sym( + orient_est.rotations, orient_est.src.rotations, orient_est.sym_grp + ) + np.testing.assert_array_less(MSE, 0.3) + + +########### +# Helpers # +########### + + +def compare_rots_sym(R_est, R_true, sym): + N = R_true.shape[0] + sym_euler = SymmetryGroup.parse(sym).matrices + order = sym_euler.shape[0] + J = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]) + error = np.zeros((N, N)) + errorJ = np.zeros((N, N)) + for i in range(N): + for j in range(N): + e = np.zeros(order) + eJ = np.zeros(order) + for s in range(order): + Rs = sym_euler[s] + e[s] = ( + np.linalg.norm(R_est[i].T @ R_est[j] - R_true[i].T @ Rs @ R_true[j]) + ** 2 + ) + eJ[s] = ( + np.linalg.norm( + R_est[i].T @ R_est[j] - J @ R_true[i].T @ Rs @ R_true[j] @ J + ) + ** 2 + ) + error[i, j] = e.min() + errorJ[i, j] = eJ.min() + E = min(error.sum(), errorJ.sum()) / N**2 + return E From 0359eab1f20d6ba736e7c66157f43e6b8bc1e786 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 May 2026 14:16:54 -0400 Subject: [PATCH 24/38] Support shifts --- src/aspire/abinitio/commonline_nug.py | 22 +++++++++++++++++----- tests/test_nug.py | 1 + 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 944394c5a4..0bb1ebf8bd 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -11,7 +11,7 @@ from aspire.utils import Rotation, cart2sph, complex_type from aspire.volume import CnSymmetryGroup, DnSymmetryGroup, SymmetryGroup -from .commonline_utils import saff_kuijlaars +from .commonline_utils import _generate_shift_phase_and_filter, saff_kuijlaars logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def __init__( symmetry=None, n_rad=None, n_theta=360, - max_shift=0, + max_shift=0.15, shift_step=1, mask=True, Lmax=12, @@ -98,6 +98,16 @@ def _build_full_pft(self): pf = self.pf self.pf_full = PolarFT.half_to_full(pf) + # Prepare the shift phases to try and generate filter for common-line detection + r_max = self.pf_full.shape[2] + self.shifts, self.shift_phases, h = _generate_shift_phase_and_filter( + r_max, self.max_shift, self.shift_step, self.dtype + ) + + # Apply bandpass filter, normalize each ray of each image + # Note that only use half of each ray + # self.pf_full = self._apply_filter_and_norm("ijk, k -> ijk", pf_full, r_max, h) + def estimate_rotations(self): self.compute_coeff() self.perform_admm() @@ -138,7 +148,6 @@ def fij(alpha, gamma, i, j, loss): # Si = Ii_hat[:, int(idxi)] # Sj = Ij_hat[:, int(idxj)] - # Using aspire PolarFT. Replace later Ii_hat = self.pf_full[i] Ij_hat = self.pf_full[j] idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta @@ -146,9 +155,11 @@ def fij(alpha, gamma, i, j, loss): Si = Ii_hat[int(idxi)] Sj = Ij_hat[int(idxj)] - # norm_new = np.linalg.norm(Si - Sj, 1) - return np.linalg.norm(Si - Sj, 1) + # Apply shifts + Sj_shifted = self.shift_phases * Sj + norms = np.linalg.norm(Si[None] - Sj_shifted, 1, axis=1) + return norms.min() if loss == "wemd": dim_wave = len(wemd_embed(line_proj[:, 0, 0])) @@ -162,6 +173,7 @@ def fij(alpha, gamma, i, j, loss): Si = WE[:, int(idxi), i] Sj = WE[:, int(idxj), j] + return np.linalg.norm(Si - Sj, 1) alpha_grid = np.arange(2 * T, dtype=np.float64) * np.pi / T diff --git a/tests/test_nug.py b/tests/test_nug.py index 1eb65a71d6..6f1b96654f 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -72,6 +72,7 @@ def source(n_img, resolution, dtype, offsets, order): def orient_est(source, proximal_refine): orient_est = CommonlineNUG( source, + max_shift=0, perform_pr=proximal_refine, ) orient_est.estimate_rotations() From 154d45f63e490e482c2ac4c205d4931e458b6816 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 26 May 2026 11:35:26 -0400 Subject: [PATCH 25/38] test shifts --- tests/test_nug.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_nug.py b/tests/test_nug.py index 6f1b96654f..3b6ebb264b 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -6,9 +6,9 @@ from aspire.volume import CnSymmetricVolume, SymmetryGroup DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] -RESOLUTION = [48, pytest.param(48, marks=pytest.mark.expensive)] -N_IMG = [5] -OFFSETS = [0] +RESOLUTION = [48, pytest.param(49, marks=pytest.mark.expensive)] +N_IMG = [15] +OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] ORDER = [3, pytest.param(4, marks=pytest.mark.expensive)] PR = [False] SEED = 1980 @@ -70,9 +70,15 @@ def source(n_img, resolution, dtype, offsets, order): @pytest.fixture(scope="module") def orient_est(source, proximal_refine): + max_shift = 0 + shift_step = 1 + if source.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 orient_est = CommonlineNUG( source, - max_shift=0, + max_shift=max_shift, + shift_step=shift_step, perform_pr=proximal_refine, ) orient_est.estimate_rotations() @@ -105,7 +111,7 @@ def test_estimate_rotations_pairwise(orient_est): MSE = compare_rots_sym( orient_est.rotations, orient_est.src.rotations, orient_est.sym_grp ) - np.testing.assert_array_less(MSE, 0.3) + np.testing.assert_array_less(MSE, 0.1) ########### From b67536392f6991e6d05f7576cd966db4881ccf8b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 26 May 2026 13:17:04 -0400 Subject: [PATCH 26/38] Test Dn symmetry --- tests/test_nug.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_nug.py b/tests/test_nug.py index 3b6ebb264b..5d5be97ffb 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -3,7 +3,7 @@ from aspire.abinitio import CommonlineNUG from aspire.source import Simulation -from aspire.volume import CnSymmetricVolume, SymmetryGroup +from aspire.volume import CnSymmetricVolume, DnSymmetricVolume, SymmetryGroup DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] RESOLUTION = [48, pytest.param(49, marks=pytest.mark.expensive)] @@ -12,6 +12,10 @@ ORDER = [3, pytest.param(4, marks=pytest.mark.expensive)] PR = [False] SEED = 1980 +VOLUME = [ + CnSymmetricVolume, + pytest.param(DnSymmetricVolume, marks=pytest.mark.expensive), +] @pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") @@ -44,14 +48,19 @@ def proximal_refine(request): return request.param +@pytest.fixture(params=VOLUME, ids=lambda x: f"Volume={x}", scope="module") +def Volume(request): + return request.param + + ############ # Fixtures # ############ @pytest.fixture(scope="module") -def source(n_img, resolution, dtype, offsets, order): - vol = CnSymmetricVolume( +def source(n_img, resolution, dtype, offsets, order, Volume): + vol = Volume( L=resolution, order=order, C=1, K=100, dtype=dtype, seed=SEED ).generate() From 5611791bd116d9feef4c22257cc814b681667e86 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 29 May 2026 10:43:18 -0400 Subject: [PATCH 27/38] move compare_rots_sym to utils --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_utils.py | 30 +++++++++++++++++++++ tests/test_nug.py | 36 +------------------------ 3 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index cb68c8f544..42f40f5955 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -2,6 +2,7 @@ from .J_sync import JSync from .commonline_utils import ( build_outer_products, + compare_rots_sym, g_sync, saff_kuijlaars, ) diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index 5fe88a21b1..bdf955bd4d 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -5,6 +5,7 @@ from aspire.operators import PolarFT from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, cyclic_rotations, tqdm +from aspire.volume import SymmetryGroup logger = logging.getLogger(__name__) @@ -418,3 +419,32 @@ def build_outer_products(n, dtype): viis[i] = np.outer(gt_vis[i], gt_vis[i]) return vijs, viis, gt_vis + + +def compare_rots_sym(R_est, R_true, sym): + N = R_true.shape[0] + sym_euler = SymmetryGroup.parse(sym).matrices + order = sym_euler.shape[0] + J = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]) + error = np.zeros((N, N)) + errorJ = np.zeros((N, N)) + for i in range(N): + for j in range(N): + e = np.zeros(order) + eJ = np.zeros(order) + for s in range(order): + Rs = sym_euler[s] + e[s] = ( + np.linalg.norm(R_est[i].T @ R_est[j] - R_true[i].T @ Rs @ R_true[j]) + ** 2 + ) + eJ[s] = ( + np.linalg.norm( + R_est[i].T @ R_est[j] - J @ R_true[i].T @ Rs @ R_true[j] @ J + ) + ** 2 + ) + error[i, j] = e.min() + errorJ[i, j] = eJ.min() + E = min(error.sum(), errorJ.sum()) / N**2 + return E diff --git a/tests/test_nug.py b/tests/test_nug.py index 5d5be97ffb..62c86834a0 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from aspire.abinitio import CommonlineNUG +from aspire.abinitio import CommonlineNUG, compare_rots_sym from aspire.source import Simulation from aspire.volume import CnSymmetricVolume, DnSymmetricVolume, SymmetryGroup @@ -121,37 +121,3 @@ def test_estimate_rotations_pairwise(orient_est): orient_est.rotations, orient_est.src.rotations, orient_est.sym_grp ) np.testing.assert_array_less(MSE, 0.1) - - -########### -# Helpers # -########### - - -def compare_rots_sym(R_est, R_true, sym): - N = R_true.shape[0] - sym_euler = SymmetryGroup.parse(sym).matrices - order = sym_euler.shape[0] - J = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]) - error = np.zeros((N, N)) - errorJ = np.zeros((N, N)) - for i in range(N): - for j in range(N): - e = np.zeros(order) - eJ = np.zeros(order) - for s in range(order): - Rs = sym_euler[s] - e[s] = ( - np.linalg.norm(R_est[i].T @ R_est[j] - R_true[i].T @ Rs @ R_true[j]) - ** 2 - ) - eJ[s] = ( - np.linalg.norm( - R_est[i].T @ R_est[j] - J @ R_true[i].T @ Rs @ R_true[j] @ J - ) - ** 2 - ) - error[i, j] = e.min() - errorJ[i, j] = eJ.min() - E = min(error.sum(), errorJ.sum()) / N**2 - return E From 3324cacf544498ce9b7159e47e7b190d00dccce6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 4 Jun 2026 14:34:51 -0400 Subject: [PATCH 28/38] Add generalized g_sync --- src/aspire/abinitio/__init__.py | 1 + src/aspire/abinitio/commonline_utils.py | 181 +++++++++++++++++++++++- tests/test_nug.py | 10 +- tests/test_orient_symmetric.py | 2 +- 4 files changed, 185 insertions(+), 9 deletions(-) diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 42f40f5955..c5b2a23d0e 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -4,6 +4,7 @@ build_outer_products, compare_rots_sym, g_sync, + g_sync_finite_group, saff_kuijlaars, ) from .commonline_base import Orient3D diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index bdf955bd4d..1d73b8fa16 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -5,7 +5,7 @@ from aspire.operators import PolarFT from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, cyclic_rotations, tqdm -from aspire.volume import SymmetryGroup +from aspire.volume import CnSymmetryGroup, SymmetryGroup logger = logging.getLogger(__name__) @@ -320,10 +320,31 @@ def saff_kuijlaars(N): return mesh -def g_sync(rots, order, rots_gt): +def g_sync(rots, rots_gt, symmetry): """ Given ground truth rotations, synchronize estimated rotations over - symmetry group elements. + symmetry group elements. This method dispatches to either the faster + cyclic implementation for Cn symmetry or the generalized version for + all other symmetries. + + :param rots: Estimated rotation matrices + :param rots_gt: Ground truth rotation matrices + :param symmetry: The symmetry of the underlying molecule. + + :return: g-synchronized ground truth rotations. + """ + sym_grp = SymmetryGroup.parse(symmetry) + + if isinstance(sym_grp, CnSymmetryGroup): + return g_sync_cyclic(rots, rots_gt, symmetry) + # else + return g_sync_finite_group(rots, rots_gt, symmetry) + + +def g_sync_cyclic(rots, rots_gt, symmetry): + """ + Given ground truth rotations, synchronize estimated rotations over + cyclic symmetry group elements. Every estimated rotation might be a version of the ground truth rotation rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the @@ -331,8 +352,8 @@ def g_sync(rots, order, rots_gt): to all estimates for error analysis. :param rots: Estimated rotation matrices - :param order: The cyclic order asssociated with the symmetry of the underlying molecule. - :param rots_gt: Ground truth rotation matrices. + :param rots_gt: Ground truth rotation matrices + :param symmetry: The symmetry of the underlying molecule. :return: g-synchronized ground truth rotations. """ @@ -342,8 +363,8 @@ def g_sync(rots, order, rots_gt): n_img = len(rots) dtype = rots.dtype - rots_symm = cyclic_rotations(order, dtype).matrices - + rots_symm = SymmetryGroup.parse(symmetry).matrices + order = len(rots_symm) A_g = np.zeros((n_img, n_img), dtype=complex) pairs = all_pairs(n_img) @@ -386,6 +407,152 @@ def g_sync(rots, order, rots_gt): return rots_gt_sync +def g_sync_finite_group(rots, rots_gt, symmetry): + """ + Synchronize ground-truth rotations over a finite symmetry group. + + This is a finite-group generalization of cyclic synchronization. The + pairwise matching step estimates relative symmetry elements between image + pairs. The spectral synchronization step then recovers one symmetry + element per image that is globally consistent with those pairwise estimates. + + Unlike the cyclic case, the relative symmetry cannot generally be encoded + as a scalar complex phase. For non-commutative groups such as D_n, we instead + represent each group element by its left-regular permutation matrix. + + :param rots: Estimated rotation matrices + :param rots_gt: Ground truth rotation matrices + :param symmetry: The symmetry of the underlying molecule. + + :return: g-synchronized ground truth rotations. + """ + assert len(rots) == len( + rots_gt + ), "Number of estimates not equal to number of references." + + n_img = len(rots) + dtype = rots.dtype + + # All matrices in the symmetry group. + G = SymmetryGroup.parse(symmetry).matrices + n_group = len(G) + + def find_group_index(A): + """Return the index of the group matrix closest to A.""" + dists = np.linalg.norm(G - A, axis=(1, 2)) + return np.argmin(dists) + + # Build multiplication and inverse tables for the symmetry group. + # mult[a, b] is the index c such that: + # G[a] @ G[b] == G[c] + # inv[a] is the index b such that: + # G[a] @ G[b] == identity + mult = np.empty((n_group, n_group), dtype=int) + inv = np.empty(n_group, dtype=int) + + for a in range(n_group): + inv[a] = find_group_index(G[a].T) + + for b in range(n_group): + mult[a, b] = find_group_index(G[a] @ G[b]) + + # Build the left-regular representation of the group. + # + # reps[a] is an n_group x n_group permutation matrix representing left + # multiplication by group element a: + # + # reps[a] @ e_b = e_{a b} + # + # This lets us store arbitrary finite-group relative elements in a block + # synchronization matrix. This is the key generalization beyond cyclic + # scalar phases in g_sync_Cn. + reps = np.zeros((n_group, n_group, n_group), dtype=float) + + for a in range(n_group): + for b in range(n_group): + reps[a, mult[a, b], b] = 1.0 + + # Block synchronization matrix. + # A is made of n_img x n_img blocks, each of size n_group x n_group. + # Block (i, j) stores the representation of the estimated relative + # symmetry element between images i and j. + A = np.zeros((n_img * n_group, n_img * n_group), dtype=float) + + # The relative symmetry from an image to itself is the identity. + for i in range(n_img): + sl_i = slice(i * n_group, (i + 1) * n_group) + A[sl_i, sl_i] = np.eye(n_group) + + for i, j in all_pairs(n_img): + # Estimated relative rotation. + Ri = rots[i] + Rj = rots[j] + Rij = Ri.T @ Rj + + # Ground-truth rotations for this pair. + Ri_gt = rots_gt[i] + Rj_gt = rots_gt[j] + + # Try every symmetry element and find which one makes the + # ground-truth relative rotation most closely match the estimated + # relative rotation. + diffs = np.zeros(n_group, dtype=float) + + for s, g_s in enumerate(G): + Rij_gt = Ri_gt.T @ g_s @ Rj_gt + diffs[s] = min( + np.linalg.norm(Rij - Rij_gt), + np.linalg.norm(Rij - J_conjugate(Rij_gt)), + ) + + # Estimates relative group element h_i^{-1} h_j. + idx_ij = np.argmin(diffs) + + sl_i = slice(i * n_group, (i + 1) * n_group) + sl_j = slice(j * n_group, (j + 1) * n_group) + + # Store the pairwise relative group elements in blocks (i, j)/(j, i). + A[sl_i, sl_j] = reps[idx_ij] + A[sl_j, sl_i] = reps[idx_ij].T + + # Spectral synchronization: + # In the noiseless case, this block matrix has a top eigenspace of + # dimension n_group. That eigenspace encodes the unknown per-image group + # elements, up to one global group action. + _, eig_vecs = np.linalg.eigh(A) + V = eig_vecs[:, -n_group:] + + # Fix the global gauge using image 0 as a reference: + # The recovered group elements are only determined up to a common global + # symmetry. This is fine for error analysis, because a single remaining + # global rotation/symmetry can still be applied later. + V0 = V[:n_group, :] + V0_pinv = np.linalg.pinv(V0) + + rots_gt_sync = np.zeros_like(rots_gt, dtype=dtype) + + for i, rot_gt in enumerate(rots_gt): + sl_i = slice(i * n_group, (i + 1) * n_group) + Vi = V[sl_i, :] + + # Compare image i's eigenspace block to the reference block. + # Ideally, this matrix is close to the representation of one group + # element. Noise makes it approximate, so we round it to the nearest + # valid group representation below. + M_i = Vi @ V0_pinv + + # Round to nearest group representation. + dists = np.linalg.norm(M_i - reps, axis=(1, 2)) + q_i = np.argmin(dists) + + # Apply the synchronized group element to the ground-truth rotation. + h_i_sync = inv[q_i] + + rots_gt_sync[i] = G[h_i_sync] @ rot_gt + + return rots_gt_sync + + def build_outer_products(n, dtype): """ Builds sets of outer products of 3rd rows of rotation matrices. diff --git a/tests/test_nug.py b/tests/test_nug.py index 62c86834a0..c83eb57303 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from aspire.abinitio import CommonlineNUG, compare_rots_sym +from aspire.abinitio import CommonlineNUG, compare_rots_sym, g_sync from aspire.source import Simulation +from aspire.utils import mean_aligned_angular_distance from aspire.volume import CnSymmetricVolume, DnSymmetricVolume, SymmetryGroup DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] @@ -121,3 +122,10 @@ def test_estimate_rotations_pairwise(orient_est): orient_est.rotations, orient_est.src.rotations, orient_est.sym_grp ) np.testing.assert_array_less(MSE, 0.1) + + +def test_estimate_rotations(orient_est): + gt_rots_synced = g_sync( + orient_est.rotations, orient_est.src.rotations, orient_est.sym_grp + ) + mean_aligned_angular_distance(orient_est.rotations, gt_rots_synced, 8.0) diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index ed9c5a6904..df86ba020a 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -123,7 +123,7 @@ def test_estimate_rotations(n_img, L, order, dtype): rots_gt = src.rotations # g-synchronize ground truth rotations. - rots_gt_sync = g_sync(rots_est, order, rots_gt) + rots_gt_sync = g_sync(rots_est, rots_gt, src.symmetry_group) # Register estimates to ground truth rotations and check that the # mean angular distance between them is less than 3 degrees. From 43b8149438bbb539bebc22adce3584bada01f508 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 4 Jun 2026 14:37:29 -0400 Subject: [PATCH 29/38] tox --- src/aspire/abinitio/commonline_utils.py | 2 +- tests/test_nug.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index 1d73b8fa16..3d2745b656 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -4,7 +4,7 @@ from numpy.linalg import eigh, norm from aspire.operators import PolarFT -from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, cyclic_rotations, tqdm +from aspire.utils import J_conjugate, Rotation, all_pairs, anorm, tqdm from aspire.volume import CnSymmetryGroup, SymmetryGroup logger = logging.getLogger(__name__) diff --git a/tests/test_nug.py b/tests/test_nug.py index c83eb57303..a2f5adad1a 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -4,7 +4,7 @@ from aspire.abinitio import CommonlineNUG, compare_rots_sym, g_sync from aspire.source import Simulation from aspire.utils import mean_aligned_angular_distance -from aspire.volume import CnSymmetricVolume, DnSymmetricVolume, SymmetryGroup +from aspire.volume import CnSymmetricVolume, DnSymmetricVolume DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] RESOLUTION = [48, pytest.param(49, marks=pytest.mark.expensive)] From e29a80eba3e4f1014a1206f06ba3b48a4f5bf4c0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 8 Jun 2026 14:34:55 -0400 Subject: [PATCH 30/38] Add g_sync test --- src/aspire/abinitio/commonline_utils.py | 3 ++ tests/test_commonline_utils.py | 53 ++++++++++++++++++++++++- tests/test_orient_d2.py | 2 +- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index 3d2745b656..ec4c0cfbf7 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -394,6 +394,9 @@ def g_sync_cyclic(rots, rots_gt, symmetry): _, eig_vecs = eigh(A_g) leading_eig_vec = eig_vecs[:, -1] + # Remove arbitrary global phase (eigh returns +-1 eigenvector) + leading_eig_vec *= np.exp(-1j * np.angle(leading_eig_vec[0])) + angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) diff --git a/tests/test_commonline_utils.py b/tests/test_commonline_utils.py index 64d015d1dd..818d829c6a 100644 --- a/tests/test_commonline_utils.py +++ b/tests/test_commonline_utils.py @@ -1,13 +1,20 @@ import numpy as np import pytest -from aspire.abinitio import JSync +from aspire.abinitio import JSync, g_sync from aspire.abinitio.commonline_utils import ( _complete_third_row_to_rot, _estimate_third_rows, build_outer_products, ) -from aspire.utils import J_conjugate, Rotation, randn, utest_tolerance +from aspire.utils import ( + J_conjugate, + Rotation, + mean_aligned_angular_distance, + randn, + utest_tolerance, +) +from aspire.volume import SymmetryGroup DTYPES = [np.float32, np.float64] @@ -116,3 +123,45 @@ def test_J_sync(dtype): np.testing.assert_allclose(Rijs_sync, Rijs_gt) assert Rijs_sync.dtype == dtype + + +@pytest.mark.parametrize("symmetry", ["C3", "C4", "D3", "D4", "T", "O"]) +def test_g_sync(symmetry): + n = 100 + dtype = np.float64 + + # Get symmetry group matrices + gs = SymmetryGroup.parse(symmetry).matrices + + # Build set of ground truth rotations + gt_rots = Rotation.generate_random_rotations(n, dtype=dtype) + + # Build set of estimates which are close to ground truth + # by generating set of small perturbation rotations to apply + # to ground truth rotations. + target_mean_deg = 2.0 + axes = np.random.normal(size=(n, 3)).astype(dtype) + axes /= np.linalg.norm(axes, axis=1, keepdims=True) + angles = np.random.uniform(0, 2 * np.deg2rad(target_mean_deg), n).astype(dtype) + delta_rots = Rotation.from_rotvec(axes * angles[:, None], dtype=dtype) + perturbed_rots = Rotation(delta_rots.matrices @ gt_rots.matrices) + + # Get mean ang dist for aligned estimates + # and check we're close to target. + og_maad = mean_aligned_angular_distance(perturbed_rots, gt_rots) + np.testing.assert_array_less(abs(og_maad - target_mean_deg), 0.2) + + # Simulate symmetry desynchronization + g_idx = np.random.randint(len(gs), size=n) + desynced_rots = Rotation(gs[g_idx] @ perturbed_rots.matrices) + + # Mean aligned angular distance of unsynced rots should be bad + np.testing.assert_array_less( + 10 * og_maad, mean_aligned_angular_distance(desynced_rots, gt_rots) + ) + + # Perform g_sync and check that mean aligned angular distance + # matches ground truth MAAD to within .1 degrees. + rots_gt_synced = g_sync(desynced_rots, gt_rots, symmetry) + est_maad = mean_aligned_angular_distance(desynced_rots, rots_gt_synced) + np.testing.assert_array_less(abs(og_maad - est_maad), 0.1) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index cc2013d07a..2d1bb73d2b 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from aspire.abinitio import CLSymmetryD2 +from aspire.abinitio import CLSymmetryD2, g_sync from aspire.source import Simulation from aspire.utils import ( J_conjugate, From 40937f4b4494108fcc4247b2b48a08e60b07024b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 8 Jun 2026 15:20:11 -0400 Subject: [PATCH 31/38] Use new g_sync in D2 test. Remove unused method --- tests/test_orient_d2.py | 74 +---------------------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 2d1bb73d2b..1c05691c94 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -103,7 +103,7 @@ def test_estimate_rotations(orient_est): rots_gt = orient_est.src.rotations # g-sync ground truth rotations. - rots_gt_sync = g_sync_d2(rots_est, rots_gt) + rots_gt_sync = g_sync(rots_est, rots_gt, "D2") # Register estimates to ground truth rotations and check that the mean angular # distance between them is less than 5 degrees. @@ -382,78 +382,6 @@ def test_sync_signs(orient_est): assert rots_est.dtype == orient_est.dtype -#################### -# Helper Functions # -#################### - - -def g_sync_d2(rots, rots_gt): - """ - Every estimated rotation might be a version of the ground truth rotation - rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the - ground truth rotations so that only a single global rotation need be applied - to all estimates for error analysis. - - :param rots: Estimated rotation matrices - :param rots_gt: Ground truth rotation matrices. - - :return: g-synchronized ground truth rotations. - """ - assert len(rots) == len( - rots_gt - ), "Number of estimates not equal to number of references." - n_img = len(rots) - dtype = rots.dtype - - rots_symm = DnSymmetryGroup(2).matrices.astype(dtype, copy=False)[[0, 2, 1, 3]] - order = len(rots_symm) - - A_g = np.zeros((n_img, n_img), dtype=complex) - - pairs = all_pairs(n_img) - - for i, j in pairs: - Ri = rots[i] - Rj = rots[j] - Rij = Ri.T @ Rj - - Ri_gt = rots_gt[i] - Rj_gt = rots_gt[j] - - diffs = np.zeros(order) - for s, g_s in enumerate(rots_symm): - Rij_gt = Ri_gt.T @ g_s @ Rj_gt - diffs[s] = min( - [ - np.linalg.norm(Rij - Rij_gt), - np.linalg.norm(Rij - J_conjugate(Rij_gt)), - ] - ) - - idx = np.argmin(diffs) - - A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx) - - # A_g(k,l) is exp(-j(-theta_k+theta_l)) - # Diagonal elements correspond to exp(-i*0) so put 1. - # This is important only for verification purposes that spectrum is (K,0,0,0...,0). - A_g += np.conj(A_g).T + np.eye(n_img) - _, eig_vecs = np.linalg.eigh(A_g) - leading_eig_vec = eig_vecs[:, -1] - - angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) - rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) - - for i, rot_gt in enumerate(rots_gt): - # Since the closest ccw or cw rotation are just as good, - # we take the absolute value of the angle differences. - angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles)) - power_g_Ri = np.argmin(angle_dists) - rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt - - return rots_gt_sync - - def build_cl_from_source(source): # Search for common lines over less shifts for 0 offsets. max_shift = 0 From 52514406ebcb1e5ad74f45c0cc7c3643b2491437 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 8 Jun 2026 15:22:36 -0400 Subject: [PATCH 32/38] remove unused imports --- tests/test_orient_d2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py index 1c05691c94..fba062ab9a 100644 --- a/tests/test_orient_d2.py +++ b/tests/test_orient_d2.py @@ -7,11 +7,10 @@ J_conjugate, Random, Rotation, - all_pairs, mean_aligned_angular_distance, utest_tolerance, ) -from aspire.volume import DnSymmetricVolume, DnSymmetryGroup +from aspire.volume import DnSymmetricVolume ############## # Parameters # From 9729fd5a1838f7628430fdb5e3b0548e19f0fbb2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 9 Jun 2026 13:47:41 -0400 Subject: [PATCH 33/38] remove fast_radon_transform and wemd code path --- src/aspire/abinitio/commonline_nug.py | 104 ++++---------------------- tests/test_nug.py | 2 +- 2 files changed, 16 insertions(+), 90 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 0bb1ebf8bd..3114160fce 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -5,9 +5,8 @@ from scipy.special import factorial from aspire.abinitio import Orient3D -from aspire.nufft import nufft -from aspire.numeric import fft, xp -from aspire.operators import PolarFT, wemd_embed +from aspire.numeric import xp +from aspire.operators import PolarFT from aspire.utils import Rotation, cart2sph, complex_type from aspire.volume import CnSymmetryGroup, DnSymmetryGroup, SymmetryGroup @@ -31,7 +30,6 @@ def __init__( shift_step=1, mask=True, Lmax=12, - loss="l1", T=36, max_iter=501, rho=0.05, @@ -64,7 +62,6 @@ def __init__( ) self.Lmax = Lmax - self.loss = loss self.T = T self.max_iter = max_iter self.rho = rho @@ -120,61 +117,24 @@ def estimate_rotations(self): def compute_coeff(self): # compute the coefficient matrix - Img = self.src.images[:] N = self.n_img n_theta = self.n_theta - loss = self.loss Lmax = self.Lmax T = self.T - angular_sampling = np.arange(0, 360, 1, dtype=np.float64) - - # Using ASPIRE Image.project(). Leaving original method in comments for now. - line_proj = Img.project(angular_sampling).asnumpy().T - - # line_proj = np.zeros((L, n_theta, N)) - # Img_pft = np.zeros((L, n_theta, N), dtype=complex) - # Img = Img.asnumpy() - # for n in range(N): - # line_proj[:, :, n], Img_pft[:, :, n] = self.fast_radon_transform( - # Img[n], angular_sampling - # ) - def fij(alpha, gamma, i, j, loss): - if loss == "l1": - # Ii_hat = Img_pft[:, :, i] - # Ij_hat = Img_pft[:, :, j] - # idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - # idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - - # Si = Ii_hat[:, int(idxi)] - # Sj = Ij_hat[:, int(idxj)] - - Ii_hat = self.pf_full[i] - Ij_hat = self.pf_full[j] - idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - - Si = Ii_hat[int(idxi)] - Sj = Ij_hat[int(idxj)] - - # Apply shifts - Sj_shifted = self.shift_phases * Sj - norms = np.linalg.norm(Si[None] - Sj_shifted, 1, axis=1) - return norms.min() - - if loss == "wemd": - dim_wave = len(wemd_embed(line_proj[:, 0, 0])) - WE = np.zeros((dim_wave, n_theta, N), dtype=np.float64) - for i in range(N): - for theta in range(n_theta): - WE[:, theta, i] = wemd_embed(line_proj[:, theta, i]) - idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + def fij(alpha, gamma, i, j): + Ii_hat = self.pf_full[i] + Ij_hat = self.pf_full[j] + idxi = np.round((alpha - np.pi / 2) * n_theta / 2 / np.pi) % n_theta + idxj = np.round((-gamma - np.pi / 2) * n_theta / 2 / np.pi) % n_theta - Si = WE[:, int(idxi), i] - Sj = WE[:, int(idxj), j] + Si = Ii_hat[int(idxi)] + Sj = Ij_hat[int(idxj)] - return np.linalg.norm(Si - Sj, 1) + # Apply shifts + Sj_shifted = self.shift_phases * Sj + norms = np.linalg.norm(Si[None] - Sj_shifted, 1, axis=1) + return norms.min() alpha_grid = np.arange(2 * T, dtype=np.float64) * np.pi / T beta_grid = (2 * np.arange(2 * T, dtype=np.float64) + 1) * np.pi / 4 / T @@ -216,7 +176,7 @@ def fijhat_k(k, F): Fij = np.zeros((2 * T, 2 * T), dtype=np.float64) for j1 in range(2 * T): for j2 in range(2 * T): - Fij[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, j, loss) + Fij[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, j) for k in range(1, Lmax + 1): dk = 2 * k + 1 C[k - 1][j * dk : (j + 1) * dk, i * dk : (i + 1) * dk] = fijhat_k( @@ -229,7 +189,7 @@ def fijhat_k(k, F): Fii = np.zeros((2 * T, 2 * T), dtype=np.float64) for j1 in range(2 * T): for j2 in range(2 * T): - Fii[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, i, loss) + Fii[j1, j2] = fij(alpha_grid[j1], gamma_grid[j2], i, i) for k in range(1, Lmax + 1): dk = 2 * k + 1 C[k - 1][i * dk : (i + 1) * dk, i * dk : (i + 1) * dk] = fijhat_k( @@ -246,40 +206,6 @@ def fijhat_k(k, F): self.C = C - @staticmethod - def fast_radon_transform(array, angles, use_ramp=False): - - angles = np.array(angles).flatten() - img_size = array.shape[1] - rads = angles / 180 * np.pi - y_idx = np.arange(-img_size / 2, img_size / 2) / img_size * 2 - x_theta = y_idx[:, np.newaxis] * np.sin(rads)[np.newaxis, :] - y_theta = y_idx[:, np.newaxis] * np.cos(rads)[np.newaxis, :] - - pts = np.pi * np.vstack( - [ - x_theta.flatten(), - y_theta.flatten(), - ] - ) - pts = pts.astype(array.dtype) - - # array = array.astype(np.float32) - lines_f = nufft(array, pts).reshape((img_size, -1)) - - if img_size % 2 == 0: - lines_f[0, :] = 0 - - if use_ramp: - freqs = np.abs(np.pi * y_idx) - lines_f *= freqs[:, np.newaxis] - - projections = np.real( - xp.asnumpy(fft.centered_ifft(xp.asarray(lines_f), axis=0)) - ) - - return projections, lines_f - def complex2real(self, ell): # compute transformation matrices that convert complex representations to real ones diml = 2 * ell + 1 diff --git a/tests/test_nug.py b/tests/test_nug.py index a2f5adad1a..b9699e7f66 100644 --- a/tests/test_nug.py +++ b/tests/test_nug.py @@ -6,7 +6,7 @@ from aspire.utils import mean_aligned_angular_distance from aspire.volume import CnSymmetricVolume, DnSymmetricVolume -DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +DTYPE = [np.float64, pytest.param(np.float32, marks=pytest.mark.expensive)] RESOLUTION = [48, pytest.param(49, marks=pytest.mark.expensive)] N_IMG = [15] OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] From c2ad19e19ae8ab166f1e27ae33db3a24e9d612c7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 9 Jun 2026 15:12:24 -0400 Subject: [PATCH 34/38] cleanup --- src/aspire/abinitio/commonline_nug.py | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 3114160fce..a236cac83f 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -17,7 +17,8 @@ class CommonlineNUG(Orient3D): """ - Class to estimate 3D orientations using non-uqique games. + Class to estimate 3D orientations using non-uqique games for molecules with cyclic + or dihedral symmetry. """ def __init__( @@ -95,16 +96,12 @@ def _build_full_pft(self): pf = self.pf self.pf_full = PolarFT.half_to_full(pf) - # Prepare the shift phases to try and generate filter for common-line detection + # Prepare the shift phases for common-line detection r_max = self.pf_full.shape[2] - self.shifts, self.shift_phases, h = _generate_shift_phase_and_filter( + self.shifts, self.shift_phases, _ = _generate_shift_phase_and_filter( r_max, self.max_shift, self.shift_step, self.dtype ) - # Apply bandpass filter, normalize each ray of each image - # Note that only use half of each ray - # self.pf_full = self._apply_filter_and_norm("ijk, k -> ijk", pf_full, r_max, h) - def estimate_rotations(self): self.compute_coeff() self.perform_admm() @@ -709,8 +706,6 @@ def ADMM_preprocessing(self, C): D1 = d1[-1] # AE and bE for quaternion constraints - # AEq = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEq"]) - # AEqAEqtinv = xp.asarray(loadmat("data/Eq_constraints/AEqJ.mat")["AEqAEqtinv"]) AEq = xp.asarray(self.construct_AEq()) AEqAEqtinv = xp.linalg.pinv(AEq @ AEq.T) @@ -721,16 +716,6 @@ def ADMM_preprocessing(self, C): # AI and bI W0, W1, Ngrid = self.compute_fejer_weights() - # AI_mat=np.zeros((Ngrid,D0+D1)) - # for p in range(Ngrid): - # w0=np.zeros(D0); w1=np.zeros(D1) - # for k in range(1,Lmax+1): - # w0[d0[k-1]:d0[k]]=(Lmax-k+2)*(Lmax-k+1)*(k+0.5)*W0[k-1][p].T.reshape(-1) - # w1[d1[k-1]:d1[k]]=(Lmax-k+2)*(Lmax-k+1)*(k+0.5)*W1[k-1][p].T.reshape(-1) - # # this needs double checking - # AI_mat[p,:d0[-1]]=w0; AI_mat[p,d0[-1]:]=w1 - # AI_mat=xp.asarray(AI_mat) / 10; - # bI=-(Lmax+2)*(Lmax+1)/2 / 10 AI_mat_offdiag = np.zeros((Ngrid, D0 + D1), dtype=np.float64) for p in range(Ngrid): w0 = np.zeros(D0, dtype=np.float64) @@ -874,7 +859,6 @@ def permutek_block(Ak, k): return W0, W1, Ngrid def discretize_SO3(self): - # S2 = loadmat("design20.mat")["design"] S2 = saff_kuijlaars(self.S2_grid) S2_size = S2.shape[0] @@ -1354,7 +1338,6 @@ def largest_eigenvalue(AI, Ngrid, N): z = AI @ (AI.T @ z) Lambda += 2000 logger.info("Largest eigenvalue of AIAIT is approximately %1.2f" % Lambda) - # Lambda=xp.linalg.eigvalsh(AI@AI.T)[-1]; print(Lambda) return Lambda def compute_rank(self, Lmax): From 7ef42365a629398e50e6c10f89fc456fd1d1e7b2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 Jun 2026 11:13:24 -0400 Subject: [PATCH 35/38] Add function docstrings --- src/aspire/abinitio/commonline_nug.py | 89 ++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 8 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index a236cac83f..3fe18e2d84 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -17,8 +17,7 @@ class CommonlineNUG(Orient3D): """ - Class to estimate 3D orientations using non-uqique games for molecules with cyclic - or dihedral symmetry. + Estimate orientations of cyclically or dihedrally symmetric molecules using the non-unique games framework. """ def __init__( @@ -44,7 +43,7 @@ def __init__( **kwargs, ): """ - Initialize object for estimating 3D orientations for symmetric molecules. + Initialize the symmetric NUG orientation estimator. :param src: The source object of 2D denoised or class-averaged images with metadata :param symmetry: A string, ie. 'C3', indicating the symmetry type. @@ -93,6 +92,9 @@ def __init__( self._build_full_pft() def _build_full_pft(self): + """ + Construct the full polar Fourier transforms and candidate shift phases. + """ pf = self.pf self.pf_full = PolarFT.half_to_full(pf) @@ -103,6 +105,9 @@ def _build_full_pft(self): ) def estimate_rotations(self): + """ + Estimate rotations by computing NUG coefficients, solving the SDP relaxation, and recovering Euler angles. + """ self.compute_coeff() self.perform_admm() self.euler_est() @@ -113,6 +118,9 @@ def estimate_rotations(self): ####################### def compute_coeff(self): + """ + Compute the truncated Fourier coefficient matrices of the pairwise common-line losses. + """ # compute the coefficient matrix N = self.n_img n_theta = self.n_theta @@ -204,7 +212,9 @@ def fijhat_k(k, F): self.C = C def complex2real(self, ell): - # compute transformation matrices that convert complex representations to real ones + """ + Construct the transformation matrices between complex and real degree ell representations. + """ diml = 2 * ell + 1 Tinv = np.zeros((diml, diml), dtype=complex_type(np.float64)) for i in range(diml): @@ -234,6 +244,9 @@ def complex2real(self, ell): ############# def perform_admm(self): + """ + Solve the symmetric NUG relaxation and optionally apply proximal refinement. + """ X_est = self.admm_sym_J(self.C, self.verbose) if self.perform_pr: @@ -249,6 +262,9 @@ def perform_admm(self): self.X_est = X_est def admm_sym_J(self, C, verbose): + """ + Solve the symmetry-constrained NUG semidefinite relaxation using ADMM. + """ Lmax = self.Lmax N = self.n_img max_iter = self.max_iter @@ -656,8 +672,9 @@ def print_updates(verbose): return X_admm def ADMM_preprocessing(self, C): - # compute necessary quantities for ADMM - # compute some useful index sets + """ + Construct the transformed coefficients, constraints, indices, and initial variables used by ADMM. + """ Lmax = self.Lmax N = self.n_img count = 0 @@ -820,6 +837,9 @@ def ADMM_preprocessing(self, C): ) def compute_fejer_weights(self): + """ + Evaluate the real Wigner representation blocks used by the discretized Fejer inequality constraints. + """ SO3_grid = self.discretize_SO3() Ngrid = SO3_grid.shape[0] start = 1 @@ -859,6 +879,9 @@ def permutek_block(Ak, k): return W0, W1, Ngrid def discretize_SO3(self): + """ + Construct an approximately uniform Euler-angle grid over SO(3). + """ S2 = saff_kuijlaars(self.S2_grid) S2_size = S2.shape[0] @@ -886,6 +909,9 @@ def discretize_SO3(self): ############################ def proximal_refine(self, X_admm, weight, Penalty, r): + """ + Refine the relaxed solution by iteratively encouraging lower-rank representation matrices. + """ N = self.n_img C = self.C @@ -959,6 +985,9 @@ def low_rank_proj(X, r_step): ######################### def euler_est(self): + """ + Recover rotations and Euler angles using the estimator for the configured symmetry group. + """ X_est = self.X_est if isinstance(self.sym_grp, CnSymmetryGroup): R_est, Euler_est = self.euler_est_Cm(X_est[0], X_est[self.n_sym - 1]) @@ -969,6 +998,9 @@ def euler_est(self): self.rotations = R_est.astype(self.dtype) def euler_est_Cm(self, X1, XS): + """ + Recover Euler angles from the degree-one and degree-m solutions for cyclic symmetry. + """ S = self.n_sym N = self.n_img sym_euler = np.zeros((S, 3), dtype=np.float64) @@ -1086,7 +1118,9 @@ def find_gamma(Xm, beta, alpha): return R_est, Euler_est def euler_est_Dm(self, X_est): - + """ + Recover Euler angles from the degree-two and degree-m solutions for dihedral symmetry. + """ X2 = X_est[1] S = self.sym_grp.order N = self.n_img @@ -1240,6 +1274,9 @@ def LS_D(W1, W2, W3, W4, Br, Bi): # Helper Functions # #################### def transform_coeff(self, A, Lmax, N, IDX_upper): + """ + Transform representation matrices into the two block-vector forms used by ADMM. + """ d0 = [0] d1 = [0] for k in range(1, Lmax + 1): @@ -1254,6 +1291,9 @@ def transform_coeff(self, A, Lmax, N, IDX_upper): return A0, A1 def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdiag): + """ + Reconstruct representation matrices from the ADMM block-vector forms. + """ d0 = [0] d1 = [0] for k in range(1, Lmax + 1): @@ -1275,6 +1315,9 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia @staticmethod def permutek(Ak, k, N): + """ + Permute and split a degree-k matrix into blocks of sizes k and k + 1. + """ AkP = xp.copy(Ak) dk = 2 * k + 1 Pk = xp.eye(dk, dtype=AkP.dtype) @@ -1299,6 +1342,9 @@ def permutek(Ak, k, N): @staticmethod def permutek_back(Ak, k, N): + """ + Undo the degree-k block permutation and reconstruct the full matrix. + """ dk = 2 * k + 1 Pk = xp.eye(N * dk, dtype=Ak.dtype) idx = xp.concatenate((xp.arange(dk - k, dk), xp.arange(k + 1))) @@ -1322,11 +1368,17 @@ def permutek_back(Ak, k, N): @staticmethod def vec_block(A, N, sz, IDX_upper): + """ + Vectorize the upper-triangular image-pair blocks of a block matrix. + """ vecA = (A.reshape(N, sz, N, sz).transpose(0, 2, 3, 1)).reshape(N**2, sz**2).T return vecA[:, IDX_upper] @staticmethod def largest_eigenvalue(AI, Ngrid, N): + """ + Estimate the largest eigenvalue of the Fejér constraint operator. + """ # find the largest eigenvalue of the operator AI np.random.seed(0) z = xp.random.normal(0, 1, (Ngrid, N**2)) @@ -1341,6 +1393,9 @@ def largest_eigenvalue(AI, Ngrid, N): return Lambda def compute_rank(self, Lmax): + """ + Compute the ranks and matrices of the symmetry-averaging projectors at each degree. + """ rk = xp.zeros(Lmax, dtype=np.float64) A = [] for k in range(1, Lmax + 1): @@ -1351,6 +1406,9 @@ def compute_rank(self, Lmax): return rk, A def WD(self, J, euler): + """ + Evaluate degree-J Wigner D matrices at the supplied ZYZ Euler angles. + """ # compute Wigner D matrix alpha = euler[:, 0] beta = euler[:, 1] @@ -1366,6 +1424,9 @@ def WD(self, J, euler): @staticmethod def Wd(J, beta): + """ + Evaluate degree-J Wigner small-d matrices at the supplied polar angles. + """ # compute Wigner small d matrix d = np.zeros((len(beta), 2 * J + 1, 2 * J + 1), dtype=beta.dtype) for m in range(-J, J + 1): @@ -1393,6 +1454,9 @@ def Wd(J, beta): @staticmethod def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): + """ + Reconstruct a symmetric block matrix from its vectorized upper-triangular blocks. + """ tmp = vecA.T.reshape(N * (N + 1) // 2, sz, sz).transpose(0, 2, 1) AA = xp.zeros((N**2, sz, sz), dtype=vecA.dtype) AA[IDX_upper] = tmp @@ -1401,6 +1465,9 @@ def mat_block(vecA, N, sz, IDX_upper, IDX_lower, idx_offdiag): @staticmethod def psd_projection(B): + """ + Project one or more symmetric matrices onto the positive semidefinite cone. + """ # compute the PSD part of a symmstric matrix B_sym = (B + B.swapaxes(-1, -2)) / 2 evals, evecs = xp.linalg.eigh(B_sym) @@ -1409,6 +1476,9 @@ def psd_projection(B): @staticmethod def transform_block(A, k, Pk=None): + """ + Permute and vectorize the two invariant blocks of a degree-k matrix. + """ single = A.ndim == 2 if single: A = A[None, :, :] @@ -1432,6 +1502,9 @@ def transform_block(A, k, Pk=None): @staticmethod def transform_back_block(A0, A1, k, Pk=None): + """ + Reconstruct a degree-k matrix from its two invariant block vectors. + """ dk = 2 * k + 1 single = A0.ndim == 1 @@ -1454,7 +1527,7 @@ def transform_back_block(A0, A1, k, Pk=None): def construct_AEq(self): """ - Construct the linear equality matrix for quaternion constraints. + Construct the linear equality operator encoding the quaternion constraints. """ AEq = np.zeros((17, 21), np.float64) From d0f07b30a5406809f300a9350a4e3f9be6e71ef1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 15 Jun 2026 11:53:29 -0400 Subject: [PATCH 36/38] Add param descrips for init --- src/aspire/abinitio/commonline_nug.py | 23 +++++++++++++++++++---- src/aspire/abinitio/commonline_utils.py | 2 +- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 3fe18e2d84..94504f236b 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -45,10 +45,25 @@ def __init__( """ Initialize the symmetric NUG orientation estimator. - :param src: The source object of 2D denoised or class-averaged images with metadata - :param symmetry: A string, ie. 'C3', indicating the symmetry type. - :param n_rad: The number of points in the radial direction - :param n_theta: The number of points in the theta direction. Default = 360. + :param src: Source containing the input projection images. + :param symmetry: Cyclic or dihedral symmetry specification, such as 'C3' or + 'D4'. If omitted, uses the source symmetry. + :param n_rad: Number of radial samples in the polar Fourier transform. + :param n_theta: Number of angular samples in the polar Fourier transform. + :param max_shift: Maximum shift considered when comparing common lines. + :param shift_step: Sampling interval for candidate shifts. + :param mask: Whether to apply a circular mask to the input images. + :param Lmax: Maximum Wigner representation degree used in the relaxation. + :param T: Quadrature resolution used to compute the Fourier coefficients. + :param max_iter: Number of ADMM iterations. + :param rho: Initial ADMM penalty parameter. + :param ratio: Residual ratio used when updating the ADMM penalty. + :param factor: Scaling factor used when updating the ADMM penalty. + :param mult: Step-size multiplier for the ADMM primal update. + :param S2_grid: Number of sphere samples used to discretize SO(3). + :param Nstep_yI: Number of inequality-multiplier updates per ADMM iteration. + :param perform_pr: Whether to apply proximal refinement after ADMM. + :param verbose: Whether to log ADMM progress. """ super().__init__( diff --git a/src/aspire/abinitio/commonline_utils.py b/src/aspire/abinitio/commonline_utils.py index ec4c0cfbf7..283dfaddc8 100644 --- a/src/aspire/abinitio/commonline_utils.py +++ b/src/aspire/abinitio/commonline_utils.py @@ -468,7 +468,7 @@ def find_group_index(A): # # This lets us store arbitrary finite-group relative elements in a block # synchronization matrix. This is the key generalization beyond cyclic - # scalar phases in g_sync_Cn. + # scalar phases in g_sync_cyclic. reps = np.zeros((n_group, n_group, n_group), dtype=float) for a in range(n_group): From 100d2c5d1da441e3bcdb05539965ffdbf5d9ab28 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 15 Jun 2026 15:11:42 -0400 Subject: [PATCH 37/38] more params. cleanup. --- src/aspire/abinitio/commonline_nug.py | 120 +++++++++++++++++++------- 1 file changed, 88 insertions(+), 32 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 94504f236b..61bcaabdab 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -122,6 +122,8 @@ def _build_full_pft(self): def estimate_rotations(self): """ Estimate rotations by computing NUG coefficients, solving the SDP relaxation, and recovering Euler angles. + + :return: Estimated rotation matrices. """ self.compute_coeff() self.perform_admm() @@ -229,6 +231,10 @@ def fijhat_k(k, F): def complex2real(self, ell): """ Construct the transformation matrices between complex and real degree ell representations. + + :param ell: Wigner representation degree. + + :return: Forward and inverse change-of-basis matrices. """ diml = 2 * ell + 1 Tinv = np.zeros((diml, diml), dtype=complex_type(np.float64)) @@ -279,6 +285,11 @@ def perform_admm(self): def admm_sym_J(self, C, verbose): """ Solve the symmetry-constrained NUG semidefinite relaxation using ADMM. + + :param C: Fourier coefficient matrices of the NUG objective. + :param verbose: Whether to log ADMM progress. + + :return: Relaxed representation matrices. """ Lmax = self.Lmax N = self.n_img @@ -319,7 +330,7 @@ def admm_sym_J(self, C, verbose): n_pairs = N * (N - 1) // 2 Ngrid = self.Ngrid - rank_Ak, _ = self.compute_rank(Lmax) + rank_Ak, _ = self.compute_rank() logger.info(f"Rank of Ak: {rank_Ak}") # rank_Ak=cp.zeros(Lmax) @@ -544,8 +555,6 @@ def update_rho(X0, X1, Xd0, Xd1, Xq, bE, bEq, bI, res_X, rho, factor, normC): return rho, p_resnorm, d_resnorm def print_updates(verbose): - # X_admm=transform_coeff_back(X0,X1,Lmax,N); obj_p=0 - # for k in range(Lmax): obj_p+=xp.trace(C[k]@X_admm[k]) if verbose: obj_p = ( xp.vdot(C0[:, idx_diag], X0[:, idx_diag]) @@ -676,12 +685,7 @@ def print_updates(verbose): X0, X1, Xd0, Xd1, Xq, bE, bEq, bI, res_X, rho, factor, normC ) - X_admm = self.transform_coeff_back( - X0, X1, Lmax, N, IDX_upper, IDX_lower, idx_offdiag - ) - # if self.GPU: - # for k in range(Lmax): - # X_admm[k] = X_admm[k].get() + X_admm = self.transform_coeff_back(X0, X1, IDX_upper, IDX_lower, idx_offdiag) for k in range(Lmax): X_admm[k] = xp.asnumpy(X_admm[k]) return X_admm @@ -689,6 +693,10 @@ def print_updates(verbose): def ADMM_preprocessing(self, C): """ Construct the transformed coefficients, constraints, indices, and initial variables used by ADMM. + + :param C: Fourier coefficient matrices of the NUG objective. + + :return: Quantities required by the ADMM solver. """ Lmax = self.Lmax N = self.n_img @@ -724,7 +732,7 @@ def ADMM_preprocessing(self, C): Xnorm = np.sqrt(Xnorm) for k in range(Lmax): C[k] = xp.asarray(Xnorm / Cnorm * C[k]) - C0, C1 = self.transform_coeff(C, Lmax, N, IDX_upper) + C0, C1 = self.transform_coeff(C, IDX_upper) normC = np.sqrt(np.linalg.norm(C0) ** 2 + np.linalg.norm(C1) ** 2) del C @@ -812,18 +820,16 @@ def ADMM_preprocessing(self, C): for k in range(1, Lmax + 1): dk = 2 * k + 1 II.append(xp.eye(N * dk, dtype=np.float64)) - I0, I1 = self.transform_coeff(II, Lmax, N, IDX_upper) + I0, I1 = self.transform_coeff(II, IDX_upper) X0 = xp.zeros((D0, N * (N + 1) // 2), dtype=np.float64) X1 = xp.zeros((D1, N * (N + 1) // 2), dtype=np.float64) Xq = xp.zeros((16, N * (N - 1) // 2), dtype=np.float64) - # X0,X1=transform_coeff(II,Lmax,N); Xq=xp.zeros((16,N*(N-1))); Xq[0,:]=1; S0 = xp.copy(I0) S1 = xp.copy(I1) Sq = xp.zeros(Xq.shape, dtype=np.float64) - # S0=xp.zeros(X0.shape); S1=xp.zeros(X1.shape); Sq=xp.zeros(Xq.shape) self.Ngrid = Ngrid - # return C0,C1,normC,AEq,bEq,AEqAEqtinv,AI_mat,bI,Lambda,d0,d1,D0,D1,idx_diag,idx_offdiag,IDX_upper,IDX_lower,X0,X1,Xq,S0,S1,Sq + return ( C0, C1, @@ -854,6 +860,8 @@ def ADMM_preprocessing(self, C): def compute_fejer_weights(self): """ Evaluate the real Wigner representation blocks used by the discretized Fejer inequality constraints. + + :return: Two sets of block weights and the number of SO(3) grid points. """ SO3_grid = self.discretize_SO3() Ngrid = SO3_grid.shape[0] @@ -896,6 +904,8 @@ def permutek_block(Ak, k): def discretize_SO3(self): """ Construct an approximately uniform Euler-angle grid over SO(3). + + :return: Array of ZYZ Euler angles. """ S2 = saff_kuijlaars(self.S2_grid) S2_size = S2.shape[0] @@ -926,6 +936,13 @@ def discretize_SO3(self): def proximal_refine(self, X_admm, weight, Penalty, r): """ Refine the relaxed solution by iteratively encouraging lower-rank representation matrices. + + :param X_admm: Initial relaxed representation matrices. + :param weight: Degree-dependent refinement weights. + :param Penalty: Penalty value for each refinement step. + :param r: Rank offset for each refinement step. + + :return: Refined representation matrices. """ N = self.n_img C = self.C @@ -1014,7 +1031,12 @@ def euler_est(self): def euler_est_Cm(self, X1, XS): """ - Recover Euler angles from the degree-one and degree-m solutions for cyclic symmetry. + Recover Euler angles for cyclic symmetry. + + :param X1: Relaxed degree-one representation matrix. + :param XS: Relaxed representation matrix at the symmetry order. + + :return: Estimated rotation matrices and Euler angles. """ S = self.n_sym N = self.n_img @@ -1134,7 +1156,11 @@ def find_gamma(Xm, beta, alpha): def euler_est_Dm(self, X_est): """ - Recover Euler angles from the degree-two and degree-m solutions for dihedral symmetry. + Recover Euler angles for dihedral symmetry. + + :param X_est: Relaxed representation matrices. + + :return: Estimated rotation matrices and Euler angles. """ X2 = X_est[1] S = self.sym_grp.order @@ -1288,34 +1314,48 @@ def LS_D(W1, W2, W3, W4, Br, Bi): #################### # Helper Functions # #################### - def transform_coeff(self, A, Lmax, N, IDX_upper): + def transform_coeff(self, A, IDX_upper): """ - Transform representation matrices into the two block-vector forms used by ADMM. + Convert representation matrices to the block-vector form used by ADMM. + + :param A: Representation matrices. + :param IDX_upper: Indices of upper-triangular image pair blocks. + + :return: The two block-vector coefficient arrays. """ d0 = [0] d1 = [0] - for k in range(1, Lmax + 1): + for k in range(1, self.Lmax + 1): d0.append(d0[-1] + k**2) d1.append(d1[-1] + (k + 1) ** 2) - A0 = xp.zeros((d0[-1], N * (N + 1) // 2), dtype=np.float64) - A1 = xp.zeros((d1[-1], N * (N + 1) // 2), dtype=np.float64) - for k in range(1, Lmax + 1): - a0, a1 = self.permutek(A[k - 1], k, N) - A0[d0[k - 1] : d0[k], :] = self.vec_block(a0, N, k, IDX_upper) - A1[d1[k - 1] : d1[k], :] = self.vec_block(a1, N, k + 1, IDX_upper) + A0 = xp.zeros((d0[-1], self.n_img * (self.n_img + 1) // 2), dtype=np.float64) + A1 = xp.zeros((d1[-1], self.n_img * (self.n_img + 1) // 2), dtype=np.float64) + for k in range(1, self.Lmax + 1): + a0, a1 = self.permutek(A[k - 1], k, self.n_img) + A0[d0[k - 1] : d0[k], :] = self.vec_block(a0, self.n_img, k, IDX_upper) + A1[d1[k - 1] : d1[k], :] = self.vec_block(a1, self.n_img, k + 1, IDX_upper) return A0, A1 - def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdiag): + def transform_coeff_back(self, A0, A1, IDX_upper, IDX_lower, idx_offdiag): """ - Reconstruct representation matrices from the ADMM block-vector forms. + Reconstruct representation matrices from their ADMM block-vector form. + + :param A0: First block-vector array. + :param A1: Second block-vector array. + :param IDX_upper: Indices of upper-triangular image-pair blocks. + :param IDX_lower: Indices of lower-triangular image-pair blocks. + :param idx_offdiag: Indices of off-diagonal image pairs. + + :return: Reconstructed representation matrices. """ d0 = [0] d1 = [0] - for k in range(1, Lmax + 1): + N = self.n_img + for k in range(1, self.Lmax + 1): d0.append(d0[-1] + k**2) d1.append(d1[-1] + (k + 1) ** 2) A = [] - for k in range(1, Lmax + 1): + for k in range(1, self.Lmax + 1): dk = 2 * k + 1 Ak = xp.zeros((N * dk, N * dk), dtype=np.float64) Ak[: N * k, : N * k] = self.mat_block( @@ -1332,6 +1372,12 @@ def transform_coeff_back(self, A0, A1, Lmax, N, IDX_upper, IDX_lower, idx_offdia def permutek(Ak, k, N): """ Permute and split a degree-k matrix into blocks of sizes k and k + 1. + + :param Ak: Degree-k block matrix. + :param k: Representation degree. + :param N: Number of images. + + :return: The two permuted matrix blocks. """ AkP = xp.copy(Ak) dk = 2 * k + 1 @@ -1359,6 +1405,12 @@ def permutek(Ak, k, N): def permutek_back(Ak, k, N): """ Undo the degree-k block permutation and reconstruct the full matrix. + + :param Ak: Permuted degree-k matrix. + :param k: Representation degree. + :param N: Number of images. + + :return: Matrix in the original block ordering. """ dk = 2 * k + 1 Pk = xp.eye(N * dk, dtype=Ak.dtype) @@ -1407,13 +1459,17 @@ def largest_eigenvalue(AI, Ngrid, N): logger.info("Largest eigenvalue of AIAIT is approximately %1.2f" % Lambda) return Lambda - def compute_rank(self, Lmax): + def compute_rank(self): """ Compute the ranks and matrices of the symmetry-averaging projectors at each degree. + + :param Lmax: Maximum representation degree. + + :return: Ranks and symmetry-averaging matrices for each degree. """ - rk = xp.zeros(Lmax, dtype=np.float64) + rk = xp.zeros(self.Lmax, dtype=np.float64) A = [] - for k in range(1, Lmax + 1): + for k in range(1, self.Lmax + 1): Ak = np.sum(self.WD(k, self.sym_euler), axis=0) Ak = np.round(Ak / self.n_sym, 6) A.append(Ak) From 22cde0d12db5e072761882d8f635856186f0f58c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 15 Jun 2026 15:21:34 -0400 Subject: [PATCH 38/38] remove commented out code --- src/aspire/abinitio/commonline_nug.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/aspire/abinitio/commonline_nug.py b/src/aspire/abinitio/commonline_nug.py index 61bcaabdab..5d01f8cc77 100644 --- a/src/aspire/abinitio/commonline_nug.py +++ b/src/aspire/abinitio/commonline_nug.py @@ -333,9 +333,6 @@ def admm_sym_J(self, C, verbose): rank_Ak, _ = self.compute_rank() logger.info(f"Rank of Ak: {rank_Ak}") - # rank_Ak=cp.zeros(Lmax) - # for ell in range(Lmax): rank_Ak[ell]=np.linalg.matrix_rank(Ak(ell+1,sym_euler)) - AE = [] AEAETinv = [] for k in range(1, Lmax + 1): @@ -414,12 +411,9 @@ def fun_AI(X0, X1): tmp = xp.concatenate((X0, X1), axis=0) z[:, idx_diag] = AI_mat_diag @ tmp[:, idx_diag] z[:, idx_offdiag] = AI_mat_offdiag @ tmp[:, idx_offdiag] - # return AI_mat@xp.concatenate((X0,X1),axis=0) return z def fun_AIT(yI): - # Z=AI_mat.T@yI - # return Z[:D0,:], Z[D0:,:] Z = xp.zeros((D0 + D1, N * (N + 1) // 2), dtype=np.float64) Z[:, idx_diag] = AI_mat_diag.T @ yI[:, idx_diag] Z[:, idx_offdiag] = AI_mat_offdiag.T @ yI[:, idx_offdiag] @@ -454,7 +448,6 @@ def update_S(C0, C1, yE, yEq, yI, X0, X1, Xd0, Xd1, Xq, rho, Lmax, N): toc0 = time.perf_counter() Time[0] += toc0 - tic0 - # See if we can switch everything to C order before here. Sd0 = Sd0.T Sd1 = Sd1.T tic1 = time.perf_counter() @@ -590,7 +583,6 @@ def print_updates(verbose): idx_offdiag, ) res_psdX += np.linalg.norm(self.psd_projection(-tmp)) - # res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) tmp = self.mat_block( X1[d1[k - 1] : d1[k], :], N, @@ -600,7 +592,6 @@ def print_updates(verbose): idx_offdiag, ) res_psdX += np.linalg.norm(self.psd_projection(-tmp)) - # res_psdX+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) res_psdX = res_psdX / (1 + np.linalg.norm(X0) + np.linalg.norm(X1)) res_psdD = 0 for k in range(1, Lmax + 1): @@ -613,13 +604,11 @@ def print_updates(verbose): res_psdD += np.linalg.norm( self.psd_projection(-tmp), axis=(-2, -1) ).sum() - # res_psdD+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) res_psdD = res_psdD / (1 + np.linalg.norm(Xd0) + np.linalg.norm(Xd1)) res_psdQ = 0 for count in range(N * (N - 1) // 2): tmp = Xq[:, count].reshape(4, 4).T res_psdQ += np.linalg.norm(self.psd_projection(-tmp)) - # res_psdQ+=norm(self.psd_projection(-tmp))/(1+norm(tmp)) res_psdQ = res_psdQ / (1 + np.linalg.norm(Xq)) normS = np.sqrt( @@ -773,7 +762,7 @@ def ADMM_preprocessing(self, C): * (k + 0.5) * W1[k - 1][p].T.reshape(-1) ) - # this needs double checking + # this needs double checking (Ruiyi) AI_mat_offdiag[p, : d0[-1]] = w0 AI_mat_offdiag[p, d0[-1] :] = w1 @@ -805,7 +794,7 @@ def ADMM_preprocessing(self, C): * (k + 0.5) * (0.5 * W1[k - 1][p] + 0.5 * W1[k - 1][p].T).T.reshape(-1) ) - # this needs double checking + # this needs double checking (Ruiyi) AI_mat_diag[p, : d0[-1]] = w0 AI_mat_diag[p, d0[-1] :] = w1 AI_mat_diag = xp.asarray(AI_mat_diag) / 1 @@ -1093,10 +1082,10 @@ def find_alpha(X1): for j in range(N): z = X1[3 * i : 3 * (i + 1), 3 * j : 3 * (j + 1)][0, 0] ZZbar[i, j] = z / abs(z) - # ZZbar[i,j]=z*2/sin(beta[i])/sin(beta[j]) + z = X1[3 * i : 3 * (i + 1), 3 * j : 3 * (j + 1)][0, 2] ZZ[i, j] = -z / abs(z) - # ZZ[i,j]=-z*2/sin(beta[i])/sin(beta[j]) + evals, evecs = np.linalg.eigh(ZZbar) idx = np.argmax(abs(evals)) Z = evecs[:, idx] * np.sqrt(abs(evals[idx]))