fennol.utils.deconvolution

 1import numpy as np
 2import jax.numpy as jnp
 3import jax
 4from functools import partial
 5
 6
 7def kernel_lorentz(w, w0, gamma):
 8    sel = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
 9    w2 = np.where(sel, 1.0, w**2)
10    w02 = np.where(sel, 1.0, w0**2)
11    return gamma * w2 / (np.pi * (w2 * gamma**2 + (w2 - w02) ** 2))
12
13
def kernel_lorentz_pot(w, w0, gamma):
    """Lorentzian kernel variant whose damping term scales with ``w0**2``.

        K(w, w0) = gamma * w**2 / (pi * (w0**2 * gamma**2 + (w**2 - w0**2)**2))

    This differs from ``kernel_lorentz`` only in the damping term of the
    denominator (``w0**2 * gamma**2`` instead of ``w**2 * gamma**2``).
    Where both ``|w|`` and ``|w0|`` are below 1e-10 the squared frequencies
    are replaced by 1, giving ``1 / (pi * gamma)`` instead of 0/0.

    Args:
        w: frequency (scalar or array) appearing in the numerator.
        w0: frequency (scalar or array), broadcast against ``w``.
        gamma: width parameter.

    Returns:
        NumPy array (or NumPy scalar) of kernel values.
    """
    # Use NumPy (not jax.numpy) for the threshold test so it runs in float64
    # on host memory like the rest of the expression. jnp.abs would build a
    # JAX array (float32 unless x64 is enabled) that np.logical_and then has
    # to copy back, possibly changing which points count as "zero".
    sel = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w2 = np.where(sel, 1.0, w**2)
    w02 = np.where(sel, 1.0, w0**2)
    return gamma * w2 / (np.pi * (w02 * gamma**2 + (w2 - w02) ** 2))
19
20
def deconvolute_spectrum(
    s_in,
    omega,
    gamma,
    niteration=10,
    kernel=kernel_lorentz,
    trans=False,
    symmetrize=True,
    thr=1.0e-10,
    verbose=False,
    K_D=None,
):
    """Deconvolute a spectrum broadened by ``kernel`` with a multiplicative iteration.

    The broadening model is ``s_in[r] ~= sum_c K[r, c] * s_out[c] * domega``
    with ``K[r, c]`` proportional to ``kernel(omega[r], omega[c], gamma)``
    (normalized, see ``trans``). Starting from ``s_out = s_in``, each step is

        s_out <- s_out * (K.T @ s_in * domega) / (D @ s_out * domega),
        D = K.T @ K * domega,

    stopping early once the relative squared change
    ``sum((s_next - s_out)**2) / sum(s_out**2)`` drops below ``thr``.

    Args:
        s_in: 1D array, broadened spectrum sampled on ``omega``.
        omega: 1D array of frequencies. Only ``omega[1] - omega[0]`` is used as
            the step, so a uniform grid is presumably assumed.
        gamma: width parameter forwarded to ``kernel``.
        niteration: maximum number of iterations.
        kernel: callable ``kernel(w, w0, gamma)`` evaluated elementwise on
            flattened frequency grids.
        trans: if True, each row of K (fixed ``w``) is normalized by the kernel
            summed over an extended symmetric ``w0`` grid; otherwise each
            column (fixed ``w0``) is normalized over ``omega``.
        symmetrize: if True, ``s_in[1:]`` and ``omega[1:]`` are mirrored to
            negative frequencies before solving (so ``omega[0]`` is presumably
            0), and only the original half is returned.
        thr: convergence threshold on the relative squared change.
        verbose: print progress messages.
        K_D: optional ``(K, D)`` tuple (e.g. from a previous call) reused
            instead of recomputing; shapes must match the (possibly
            symmetrized) grid.

    Returns:
        ``(s_out, s_rec, (K, D))``: the deconvoluted spectrum, its
        reconvolution ``K @ s_out * domega`` (both restricted to the original
        grid when ``symmetrize``), and the matrices for reuse.

    Raises:
        AssertionError: on mismatched or too-short inputs.
    """
    assert s_in.shape[0] == omega.shape[0], "s_in and omega must have the same length"
    assert omega.shape[0] >= 2, "omega must contain at least two points"
    domega = omega[1] - omega[0]

    if verbose:
        print("deconvolution started.")

    if symmetrize:
        nom_save = omega.shape[0]
        s_in = np.concatenate((np.flip(s_in[1:], axis=0), s_in), axis=0)
        omega = np.concatenate((-np.flip(omega[1:], axis=0), omega), axis=0)
    nom = omega.shape[0]

    if K_D is not None:
        K, D = K_D
        assert K.shape == (nom, nom), "K and omega must have the same length"
        assert D.shape == K.shape, "D and K must have the same shape"
    else:
        if verbose:
            print("  computing kernel matrix...")
        # meshgrid + Fortran-order flatten + C-order reshape yields
        # K[r, c] = kernel(omega[r], omega[c], gamma).
        omij0, omij1 = np.meshgrid(omega, omega)
        K = kernel(
            omij0.flatten(order="F"), omij1.flatten(order="F"), gamma
        ).reshape(nom, nom)
        if trans:
            omnorm = np.arange(nom) * domega
            omnorm = np.concatenate((-np.flip(omnorm[1:], axis=0), omnorm), axis=0)
            omnormij0, omnormij1 = np.meshgrid(omega, omnorm)
            # Same index convention: Knorm[r, c] = kernel(omega[r], omnorm[c], gamma).
            Knorm = kernel(
                omnormij0.flatten(order="F"), omnormij1.flatten(order="F"), gamma
            ).reshape(nom, 2 * nom - 1)
            K = K / np.sum(Knorm, axis=1)[:, None] / domega
            del Knorm, omnormij0, omnormij1, omnorm
        else:
            K = K / np.sum(K, axis=0)[None, :] / domega

        if verbose:
            print("  kernel matrix computed.")
            print("  computing double convolution matrix...")

        D = (K.T @ K) * domega

        if verbose:
            print("  double convolution matrix computed.")

    if verbose:
        print("  starting iterations.")

    s_out = s_in
    h = (K.T @ s_in) * domega
    for i in range(niteration):
        if verbose:
            print(f"  iteration {i+1}/{niteration}")
        # D @ s_out equals (s_out[None, :] * D).sum(axis=1) without allocating
        # an nom x nom temporary on every iteration.
        # NOTE(review): a zero entry in den (or an all-zero s_out) yields NaN/inf.
        den = (D @ s_out) * domega
        s_next = s_out * h / den
        diff = ((s_next - s_out) ** 2).sum() / (s_out**2).sum()
        if verbose:
            print(f"    relative difference: {diff}")
        s_out = s_next
        if diff < thr:
            break

    if verbose:
        print("deconvolution finished.")
    s_rec = (K @ s_out) * domega
    if symmetrize:
        # Keep only the entries corresponding to the original omega grid.
        s_rec = s_rec[nom_save - 1 :]
        s_out = s_out[nom_save - 1 :]

    return s_out, s_rec, (K, D)
def kernel_lorentz(w, w0, gamma):
    """Lorentzian broadening kernel evaluated elementwise.

        K(w, w0) = gamma * w**2 / (pi * (w**2 * gamma**2 + (w**2 - w0**2)**2))

    Where both ``|w|`` and ``|w0|`` fall below 1e-10 the expression would be
    0/0; there both squared frequencies are replaced by 1, which yields the
    finite value ``1 / (pi * gamma)``.

    Args:
        w: frequency (scalar or array) appearing in the numerator.
        w0: frequency (scalar or array), broadcast against ``w``.
        gamma: width parameter.

    Returns:
        NumPy array (or NumPy scalar) of kernel values.
    """
    at_origin = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w_sq = np.where(at_origin, 1.0, w**2)
    w0_sq = np.where(at_origin, 1.0, w0**2)
    damping = w_sq * gamma**2
    detuning_sq = (w_sq - w0_sq) ** 2
    return gamma * w_sq / (np.pi * (damping + detuning_sq))
def kernel_lorentz_pot(w, w0, gamma):
    """Lorentzian kernel variant whose damping term scales with ``w0**2``.

        K(w, w0) = gamma * w**2 / (pi * (w0**2 * gamma**2 + (w**2 - w0**2)**2))

    This differs from ``kernel_lorentz`` only in the damping term of the
    denominator (``w0**2 * gamma**2`` instead of ``w**2 * gamma**2``).
    Where both ``|w|`` and ``|w0|`` are below 1e-10 the squared frequencies
    are replaced by 1, giving ``1 / (pi * gamma)`` instead of 0/0.

    Args:
        w: frequency (scalar or array) appearing in the numerator.
        w0: frequency (scalar or array), broadcast against ``w``.
        gamma: width parameter.

    Returns:
        NumPy array (or NumPy scalar) of kernel values.
    """
    # Use NumPy (not jax.numpy) for the threshold test so it runs in float64
    # on host memory like the rest of the expression. jnp.abs would build a
    # JAX array (float32 unless x64 is enabled) that np.logical_and then has
    # to copy back, possibly changing which points count as "zero".
    sel = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w2 = np.where(sel, 1.0, w**2)
    w02 = np.where(sel, 1.0, w0**2)
    return gamma * w2 / (np.pi * (w02 * gamma**2 + (w2 - w02) ** 2))
def deconvolute_spectrum(
    s_in,
    omega,
    gamma,
    niteration=10,
    kernel=kernel_lorentz,
    trans=False,
    symmetrize=True,
    thr=1.0e-10,
    verbose=False,
    K_D=None,
):
    """Deconvolute a spectrum broadened by ``kernel`` with a multiplicative iteration.

    The broadening model is ``s_in[r] ~= sum_c K[r, c] * s_out[c] * domega``
    with ``K[r, c]`` proportional to ``kernel(omega[r], omega[c], gamma)``
    (normalized, see ``trans``). Starting from ``s_out = s_in``, each step is

        s_out <- s_out * (K.T @ s_in * domega) / (D @ s_out * domega),
        D = K.T @ K * domega,

    stopping early once the relative squared change
    ``sum((s_next - s_out)**2) / sum(s_out**2)`` drops below ``thr``.

    Args:
        s_in: 1D array, broadened spectrum sampled on ``omega``.
        omega: 1D array of frequencies. Only ``omega[1] - omega[0]`` is used as
            the step, so a uniform grid is presumably assumed.
        gamma: width parameter forwarded to ``kernel``.
        niteration: maximum number of iterations.
        kernel: callable ``kernel(w, w0, gamma)`` evaluated elementwise on
            flattened frequency grids.
        trans: if True, each row of K (fixed ``w``) is normalized by the kernel
            summed over an extended symmetric ``w0`` grid; otherwise each
            column (fixed ``w0``) is normalized over ``omega``.
        symmetrize: if True, ``s_in[1:]`` and ``omega[1:]`` are mirrored to
            negative frequencies before solving (so ``omega[0]`` is presumably
            0), and only the original half is returned.
        thr: convergence threshold on the relative squared change.
        verbose: print progress messages.
        K_D: optional ``(K, D)`` tuple (e.g. from a previous call) reused
            instead of recomputing; shapes must match the (possibly
            symmetrized) grid.

    Returns:
        ``(s_out, s_rec, (K, D))``: the deconvoluted spectrum, its
        reconvolution ``K @ s_out * domega`` (both restricted to the original
        grid when ``symmetrize``), and the matrices for reuse.

    Raises:
        AssertionError: on mismatched or too-short inputs.
    """
    assert s_in.shape[0] == omega.shape[0], "s_in and omega must have the same length"
    assert omega.shape[0] >= 2, "omega must contain at least two points"
    domega = omega[1] - omega[0]

    if verbose:
        print("deconvolution started.")

    if symmetrize:
        nom_save = omega.shape[0]
        s_in = np.concatenate((np.flip(s_in[1:], axis=0), s_in), axis=0)
        omega = np.concatenate((-np.flip(omega[1:], axis=0), omega), axis=0)
    nom = omega.shape[0]

    if K_D is not None:
        K, D = K_D
        assert K.shape == (nom, nom), "K and omega must have the same length"
        assert D.shape == K.shape, "D and K must have the same shape"
    else:
        if verbose:
            print("  computing kernel matrix...")
        # meshgrid + Fortran-order flatten + C-order reshape yields
        # K[r, c] = kernel(omega[r], omega[c], gamma).
        omij0, omij1 = np.meshgrid(omega, omega)
        K = kernel(
            omij0.flatten(order="F"), omij1.flatten(order="F"), gamma
        ).reshape(nom, nom)
        if trans:
            omnorm = np.arange(nom) * domega
            omnorm = np.concatenate((-np.flip(omnorm[1:], axis=0), omnorm), axis=0)
            omnormij0, omnormij1 = np.meshgrid(omega, omnorm)
            # Same index convention: Knorm[r, c] = kernel(omega[r], omnorm[c], gamma).
            Knorm = kernel(
                omnormij0.flatten(order="F"), omnormij1.flatten(order="F"), gamma
            ).reshape(nom, 2 * nom - 1)
            K = K / np.sum(Knorm, axis=1)[:, None] / domega
            del Knorm, omnormij0, omnormij1, omnorm
        else:
            K = K / np.sum(K, axis=0)[None, :] / domega

        if verbose:
            print("  kernel matrix computed.")
            print("  computing double convolution matrix...")

        D = (K.T @ K) * domega

        if verbose:
            print("  double convolution matrix computed.")

    if verbose:
        print("  starting iterations.")

    s_out = s_in
    h = (K.T @ s_in) * domega
    for i in range(niteration):
        if verbose:
            print(f"  iteration {i+1}/{niteration}")
        # D @ s_out equals (s_out[None, :] * D).sum(axis=1) without allocating
        # an nom x nom temporary on every iteration.
        # NOTE(review): a zero entry in den (or an all-zero s_out) yields NaN/inf.
        den = (D @ s_out) * domega
        s_next = s_out * h / den
        diff = ((s_next - s_out) ** 2).sum() / (s_out**2).sum()
        if verbose:
            print(f"    relative difference: {diff}")
        s_out = s_next
        if diff < thr:
            break

    if verbose:
        print("deconvolution finished.")
    s_rec = (K @ s_out) * domega
    if symmetrize:
        # Keep only the entries corresponding to the original omega grid.
        s_rec = s_rec[nom_save - 1 :]
        s_out = s_out[nom_save - 1 :]

    return s_out, s_rec, (K, D)