fennol.utils.deconvolution

 1import numpy as np
 2import jax.numpy as jnp
 3import jax
 4from functools import partial
 5
 6
 7def kernel_lorentz(w, w0, gamma):
 8    sel = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
 9    w2 = np.where(sel, 1.0, w**2)
10    w02 = np.where(sel, 1.0, w0**2)
11    return gamma * w2 / (np.pi * (w2 * gamma**2 + (w2 - w02) ** 2))
12
13
def kernel_lorentz_pot(w, w0, gamma):
    """Evaluate the Lorentzian-shaped kernel variant damped by ``w0``.

    The value returned is::

        gamma * w**2 / (pi * (w0**2 * gamma**2 + (w**2 - w0**2)**2))

    i.e. the same form as ``kernel_lorentz`` except that the damping term in
    the denominator is weighted by ``w0**2`` instead of ``w**2``.

    Where both ``|w|`` and ``|w0|`` are below ``1e-10`` the expression would
    be ``0 / 0``; there ``w**2`` and ``w0**2`` are both replaced by ``1`` so
    the kernel evaluates to ``1 / (pi * gamma)``.

    Args:
        w: Frequency (scalar or array) at which the kernel is evaluated.
        w0: Centre frequency (scalar or array broadcastable against ``w``).
        gamma: Width parameter of the kernel.

    Returns:
        The kernel values, with the broadcast shape of ``w`` and ``w0``.
    """
    # The mask is computed with NumPy on both operands: routing ``w`` through
    # jax.numpy would round-trip the data through a JAX array and, in JAX's
    # default 32-bit mode, compare against the 1e-10 threshold in float32.
    both_zero = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w_sq = np.where(both_zero, 1.0, w**2)
    w0_sq = np.where(both_zero, 1.0, w0**2)
    denominator = np.pi * (w0_sq * gamma**2 + (w_sq - w0_sq) ** 2)
    return gamma * w_sq / denominator
19
20
def deconvolute_spectrum(
    s_in,
    omega,
    gamma,
    niteration=10,
    kernel=kernel_lorentz,
    trans=False,
    symmetrize=True,
    thr=1.0e-10,
    verbose=False,
    K_D=None,
):
    """Iteratively deconvolve a spectrum from a broadening kernel.

    A kernel matrix ``K[i, j] = kernel(omega[i], omega[j], gamma)`` is built
    and normalized, together with the double-convolution matrix
    ``D = K.T @ K * domega``.  Starting from ``s_out = s_in``, the
    multiplicative update ``s_out <- s_out * h / (D @ s_out * domega)`` with
    ``h = K.T @ s_in * domega`` is applied at most ``niteration`` times.

    Args:
        s_in: Input spectrum; its first axis must have the same length as
            ``omega``.  Additional trailing axes are carried through the
            matrix products.
        omega: Frequency grid with at least two points.  The spacing is taken
            from the first two points only, so the grid is presumably uniform
            (not checked here).
        gamma: Width parameter forwarded to ``kernel``.
        niteration: Maximum number of update iterations.
        kernel: Callable ``kernel(w, w0, gamma)`` evaluated on flattened
            frequency pairs.
        trans: If true, each row of ``K`` is normalized by the sum of the
            kernel over an extended grid of ``2 * nom - 1`` points spaced by
            ``domega``; otherwise each column of ``K`` is normalized by its
            own sum.  Ignored when ``K_D`` is given.
        symmetrize: If true, ``s_in`` and ``omega`` are mirrored about their
            first point (with ``omega`` negated) before the deconvolution and
            only the original half is returned.  NOTE(review): this looks
            like it expects ``omega[0] == 0`` -- confirm with callers.
        thr: The loop stops early once the relative squared change between
            two successive iterates falls below this value.
        verbose: Print progress messages.
        K_D: Optional precomputed ``(K, D)`` pair (e.g. the third return
            value of an earlier call on the same grid).  Both must be square
            with the size of the (possibly symmetrized) grid.

    Returns:
        ``(s_out, s_rec, (K, D))`` where ``s_out`` is the deconvolved
        spectrum, ``s_rec = K @ s_out * domega`` is its re-convolution, and
        ``(K, D)`` are the matrices that were used.
    """
    assert s_in.shape[0] == omega.shape[0], "s_in and omega must have the same length"
    domega = omega[1] - omega[0]
    if symmetrize:
        # Remember the original length so the mirrored half can be dropped again.
        nom_save = omega.shape[0]
        s_in = np.concatenate((np.flip(s_in[1:], axis=0), s_in), axis=0)
        omega = np.concatenate((-np.flip(omega[1:], axis=0), omega), axis=0)

    if K_D is not None:
        K, D = K_D
        assert K.shape == (omega.shape[0],omega.shape[0]), "K and omega must have the same length"
        assert D.shape == K.shape, "D and K must have the same shape"
    else:
        if verbose:
            print("deconvolution started.")
            print("  computing kernel matrix...")
        nom = omega.shape[0]
        # Column-major flattening followed by a row-major reshape yields
        # K[i, j] = kernel(omega[i], omega[j], gamma).
        omij0, omij1 = np.meshgrid(omega, omega)
        omij0 = omij0.flatten(order="F")
        omij1 = omij1.flatten(order="F")
        K = kernel(omij0, omij1, gamma).reshape(nom, nom)
        if trans:
            # Normalize each row by the kernel summed over an extended grid.
            omnorm = np.arange(nom) * domega
            omnorm = np.concatenate((-np.flip(omnorm[1:], axis=0), omnorm), axis=0)
            omnormij0, omnormij1 = np.meshgrid(omega, omnorm)
            omnormij0 = omnormij0.flatten(order="F")
            omnormij1 = omnormij1.flatten(order="F")
            Knorm = kernel(omnormij0, omnormij1, gamma).reshape(nom, 2 * nom - 1)
            K = K / np.sum(Knorm, axis=1)[:, None] / domega
            del Knorm, omnormij0, omnormij1, omnorm
        else:
            # Normalize each column so that it integrates to one over omega.
            K = K / np.sum(K, axis=0)[None, :] / domega

        if verbose:
            print("  kernel matrix computed.")
            print("  computing double convolution matrix...")

        D = (K.T @ K) * domega

        if verbose:
            print("  double convolution matrix computed.")
            print("  starting iterations.")

    s_out = s_in
    h = (K.T @ s_in) * domega
    for i in range(niteration):
        if verbose:
            print(f"  iteration {i+1}/{niteration}")
        # Matrix product instead of a broadcast multiply + sum: same values,
        # no n x n temporary per iteration, and trailing axes of s_in work.
        den = (D @ s_out) * domega
        s_next = s_out * h / den
        diff = ((s_next - s_out) ** 2).sum() / (s_out**2).sum()
        if verbose:
            print(f"    relative difference: {diff}")
        s_out = s_next
        if diff < thr:
            break

    if verbose:
        print("deconvolution finished.")
    s_rec = (K @ s_out) * domega
    if symmetrize:
        # Drop the mirrored half that was prepended above.
        s_rec = s_rec[nom_save - 1 :]
        s_out = s_out[nom_save - 1 :]

    return s_out, s_rec, (K, D)
def kernel_lorentz(w, w0, gamma):
    """Evaluate a Lorentzian-shaped kernel on frequency pairs.

    The value returned is::

        gamma * w**2 / (pi * (w**2 * gamma**2 + (w**2 - w0**2)**2))

    Where both ``|w|`` and ``|w0|`` are below ``1e-10`` the expression would
    be ``0 / 0``; there ``w**2`` and ``w0**2`` are both replaced by ``1`` so
    the kernel evaluates to ``1 / (pi * gamma)``.

    Args:
        w: Frequency (scalar or array) at which the kernel is evaluated.
        w0: Centre frequency (scalar or array broadcastable against ``w``).
        gamma: Width parameter of the kernel.

    Returns:
        The kernel values, with the broadcast shape of ``w`` and ``w0``.
    """
    # Points where both frequencies vanish need the 0/0 limit handled by hand.
    both_zero = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w_sq = np.where(both_zero, 1.0, w**2)
    w0_sq = np.where(both_zero, 1.0, w0**2)
    denominator = np.pi * (w_sq * gamma**2 + (w_sq - w0_sq) ** 2)
    return gamma * w_sq / denominator
def kernel_lorentz_pot(w, w0, gamma):
    """Evaluate the Lorentzian-shaped kernel variant damped by ``w0``.

    The value returned is::

        gamma * w**2 / (pi * (w0**2 * gamma**2 + (w**2 - w0**2)**2))

    i.e. the same form as ``kernel_lorentz`` except that the damping term in
    the denominator is weighted by ``w0**2`` instead of ``w**2``.

    Where both ``|w|`` and ``|w0|`` are below ``1e-10`` the expression would
    be ``0 / 0``; there ``w**2`` and ``w0**2`` are both replaced by ``1`` so
    the kernel evaluates to ``1 / (pi * gamma)``.

    Args:
        w: Frequency (scalar or array) at which the kernel is evaluated.
        w0: Centre frequency (scalar or array broadcastable against ``w``).
        gamma: Width parameter of the kernel.

    Returns:
        The kernel values, with the broadcast shape of ``w`` and ``w0``.
    """
    # The mask is computed with NumPy on both operands: routing ``w`` through
    # jax.numpy would round-trip the data through a JAX array and, in JAX's
    # default 32-bit mode, compare against the 1e-10 threshold in float32.
    both_zero = np.logical_and(np.abs(w) < 1.0e-10, np.abs(w0) < 1.0e-10)
    w_sq = np.where(both_zero, 1.0, w**2)
    w0_sq = np.where(both_zero, 1.0, w0**2)
    denominator = np.pi * (w0_sq * gamma**2 + (w_sq - w0_sq) ** 2)
    return gamma * w_sq / denominator
def deconvolute_spectrum(
    s_in,
    omega,
    gamma,
    niteration=10,
    kernel=kernel_lorentz,
    trans=False,
    symmetrize=True,
    thr=1.0e-10,
    verbose=False,
    K_D=None,
):
    """Iteratively deconvolve a spectrum from a broadening kernel.

    A kernel matrix ``K[i, j] = kernel(omega[i], omega[j], gamma)`` is built
    and normalized, together with the double-convolution matrix
    ``D = K.T @ K * domega``.  Starting from ``s_out = s_in``, the
    multiplicative update ``s_out <- s_out * h / (D @ s_out * domega)`` with
    ``h = K.T @ s_in * domega`` is applied at most ``niteration`` times.

    Args:
        s_in: Input spectrum; its first axis must have the same length as
            ``omega``.  Additional trailing axes are carried through the
            matrix products.
        omega: Frequency grid with at least two points.  The spacing is taken
            from the first two points only, so the grid is presumably uniform
            (not checked here).
        gamma: Width parameter forwarded to ``kernel``.
        niteration: Maximum number of update iterations.
        kernel: Callable ``kernel(w, w0, gamma)`` evaluated on flattened
            frequency pairs.
        trans: If true, each row of ``K`` is normalized by the sum of the
            kernel over an extended grid of ``2 * nom - 1`` points spaced by
            ``domega``; otherwise each column of ``K`` is normalized by its
            own sum.  Ignored when ``K_D`` is given.
        symmetrize: If true, ``s_in`` and ``omega`` are mirrored about their
            first point (with ``omega`` negated) before the deconvolution and
            only the original half is returned.  NOTE(review): this looks
            like it expects ``omega[0] == 0`` -- confirm with callers.
        thr: The loop stops early once the relative squared change between
            two successive iterates falls below this value.
        verbose: Print progress messages.
        K_D: Optional precomputed ``(K, D)`` pair (e.g. the third return
            value of an earlier call on the same grid).  Both must be square
            with the size of the (possibly symmetrized) grid.

    Returns:
        ``(s_out, s_rec, (K, D))`` where ``s_out`` is the deconvolved
        spectrum, ``s_rec = K @ s_out * domega`` is its re-convolution, and
        ``(K, D)`` are the matrices that were used.
    """
    assert s_in.shape[0] == omega.shape[0], "s_in and omega must have the same length"
    domega = omega[1] - omega[0]
    if symmetrize:
        # Remember the original length so the mirrored half can be dropped again.
        nom_save = omega.shape[0]
        s_in = np.concatenate((np.flip(s_in[1:], axis=0), s_in), axis=0)
        omega = np.concatenate((-np.flip(omega[1:], axis=0), omega), axis=0)

    if K_D is not None:
        K, D = K_D
        assert K.shape == (omega.shape[0],omega.shape[0]), "K and omega must have the same length"
        assert D.shape == K.shape, "D and K must have the same shape"
    else:
        if verbose:
            print("deconvolution started.")
            print("  computing kernel matrix...")
        nom = omega.shape[0]
        # Column-major flattening followed by a row-major reshape yields
        # K[i, j] = kernel(omega[i], omega[j], gamma).
        omij0, omij1 = np.meshgrid(omega, omega)
        omij0 = omij0.flatten(order="F")
        omij1 = omij1.flatten(order="F")
        K = kernel(omij0, omij1, gamma).reshape(nom, nom)
        if trans:
            # Normalize each row by the kernel summed over an extended grid.
            omnorm = np.arange(nom) * domega
            omnorm = np.concatenate((-np.flip(omnorm[1:], axis=0), omnorm), axis=0)
            omnormij0, omnormij1 = np.meshgrid(omega, omnorm)
            omnormij0 = omnormij0.flatten(order="F")
            omnormij1 = omnormij1.flatten(order="F")
            Knorm = kernel(omnormij0, omnormij1, gamma).reshape(nom, 2 * nom - 1)
            K = K / np.sum(Knorm, axis=1)[:, None] / domega
            del Knorm, omnormij0, omnormij1, omnorm
        else:
            # Normalize each column so that it integrates to one over omega.
            K = K / np.sum(K, axis=0)[None, :] / domega

        if verbose:
            print("  kernel matrix computed.")
            print("  computing double convolution matrix...")

        D = (K.T @ K) * domega

        if verbose:
            print("  double convolution matrix computed.")
            print("  starting iterations.")

    s_out = s_in
    h = (K.T @ s_in) * domega
    for i in range(niteration):
        if verbose:
            print(f"  iteration {i+1}/{niteration}")
        # Matrix product instead of a broadcast multiply + sum: same values,
        # no n x n temporary per iteration, and trailing axes of s_in work.
        den = (D @ s_out) * domega
        s_next = s_out * h / den
        diff = ((s_next - s_out) ** 2).sum() / (s_out**2).sum()
        if verbose:
            print(f"    relative difference: {diff}")
        s_out = s_next
        if diff < thr:
            break

    if verbose:
        print("deconvolution finished.")
    s_rec = (K @ s_out) * domega
    if symmetrize:
        # Drop the mirrored half that was prepended above.
        s_rec = s_rec[nom_save - 1 :]
        s_out = s_out[nom_save - 1 :]

    return s_out, s_rec, (K, D)