from typing import Tuple

import numpy as np

from encoder import (
    G,
    CHAR_SET,
    pair_to_index,
    index_to_pair,
    make_codebook,
)

# Diagnostic: dump NumPy's build / BLAS configuration at import time.
# ``np.show_config()`` prints directly and returns None, so it must not be
# wrapped in ``print()`` (doing so emitted an extra "None" line).
# TODO(review): consider moving this behind a debug flag or __main__ guard.
np.show_config()


def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> Tuple[str, int]:
    """Decode ``Y`` block by block against codebook ``C`` and pick a state.

    ``Y`` is split into ``nb`` consecutive blocks of ``2 * C.shape[1]``
    samples. Within each block, the even-indexed and odd-indexed samples are
    each correlated against every row (codeword) of ``C``. Two hypotheses are
    scored per block and codeword:

    * state 1: ``sqrt(G) * (even . c) + (odd . c)``
    * state 2: ``(even . c) + sqrt(G) * (odd . c)``

    For each hypothesis, the best-scoring codeword per block is taken. The
    state whose summed best scores over all blocks is larger wins; ties go to
    state 1. The winning state's per-block codeword indices are mapped back
    through ``index_to_pair``, and each pair ``(a, b)`` is concatenated as
    ``a + b``.

    Args:
        Y: Received samples. ``Y.size`` must be a whole multiple of
            ``2 * C.shape[1]``. Any shape is accepted; it is flattened
            by ``reshape``.
        C: Codebook with one codeword per row.

    Returns:
        ``(decoded, state)``: the decoded string and the selected state
        (1 or 2).

    Raises:
        AssertionError: if ``Y.size`` is not a whole number of blocks (only
            when assertions are enabled; under ``-O`` the subsequent
            ``reshape`` raises ``ValueError`` instead).
    """
    n_cw = C.shape[1] * 2  # samples per block: even + odd halves, each one codeword long
    nb = Y.size // n_cw
    assert nb * n_cw == Y.size, "length mismatch"

    sqrtg = np.sqrt(G)

    # Split every block into its even / odd samples in one vectorized step
    # instead of looping over blocks in Python.
    Y_blocks = Y.reshape(nb, n_cw)
    Y_even = Y_blocks[:, ::2].astype(np.float32, copy=False)  # (nb, C.shape[1])
    Y_odd = Y_blocks[:, 1::2].astype(np.float32, copy=False)  # (nb, C.shape[1])
    C_t = C.astype(np.float32, copy=False).T

    # (nb, L) @ (L, K) -> (nb, K). Each correlation is computed once and
    # reused for both hypotheses, which halves the matmul work.
    corr_even = Y_even @ C_t
    corr_odd = Y_odd @ C_t

    # Same expressions as before, so the float results are bit-identical.
    s1 = sqrtg * corr_even + corr_odd
    s2 = corr_even + sqrtg * corr_odd
    del corr_even, corr_odd  # release the intermediates before the reductions

    best_if_s1 = np.argmax(s1, axis=1)
    best_if_s2 = np.argmax(s2, axis=1)
    tot1 = np.sum(np.max(s1, axis=1))
    tot2 = np.sum(np.max(s2, axis=1))

    state = 1 if tot1 >= tot2 else 2
    indices = best_if_s1 if state == 1 else best_if_s2

    decoded = ''.join(a + b for a, b in (index_to_pair(i) for i in indices))
    return decoded, state


def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
    """Decode ``Y`` against codebook ``C``, discarding the selected state.

    Thin wrapper around ``decode_blocks_with_state``; see it for details.
    """
    decoded, _state = decode_blocks_with_state(Y, C)
    return decoded