pdc_project/decoder.py
2025-05-30 00:44:20 +02:00

43 lines
1.6 KiB
Python

import numpy as np
from encoder import (
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
)
# Debug aid: dump the libraries/system info NumPy was built with (e.g. which
# BLAS backs the float32 matmuls in the decoder). show_config() prints by
# itself and returns None, so it must not be wrapped in print() -- doing so
# emitted a stray "None" line on every import.
np.show_config()
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
    """Decode received samples ``Y`` against codebook ``C`` and pick a channel state.

    ``Y`` is flattened (C order) and cut into consecutive blocks of
    ``2 * C.shape[1]`` samples. Within each block the even-indexed and the
    odd-indexed samples are each correlated with every row (codeword) of ``C``.
    Two hypotheses are scored:

    * state 1: the even-sample correlation is weighted by ``sqrt(G)``;
    * state 2: the odd-sample correlation is weighted by ``sqrt(G)``.

    NOTE(review): presumably this models a channel that amplifies one of the
    two transmitted copies of each codeword -- confirm against the encoder /
    channel description.

    The hypothesis whose per-block maximum scores sum higher wins (ties go to
    state 1). Every block is then decoded to its best-scoring codeword under
    that hypothesis, and each codeword index is mapped to two characters via
    ``index_to_pair``.

    Args:
        Y: Received samples, any shape. ``Y.size`` must be a multiple of
            ``2 * C.shape[1]``.
        C: Codebook of shape ``(n_codewords, L)``; row ``i`` is codeword ``i``.
            It is cast to float32 for the correlations.

    Returns:
        ``(decoded, state)``, where ``decoded`` is the concatenation of the
        character pairs of all blocks and ``state`` is ``1`` or ``2``.

    Raises:
        ValueError: if ``Y.size`` is not a multiple of ``2 * C.shape[1]``.
    """
    # Each block holds every codeword sample twice (even + odd positions).
    n_cw = 2 * C.shape[1]
    nb = Y.size // n_cw
    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if nb * n_cw != Y.size:
        raise ValueError("length mismatch")
    sqrtg = np.sqrt(G)

    # Vectorised even/odd extraction for all blocks at once (no Python loop).
    Y_blocks = Y.reshape(nb, n_cw)
    Y_even = Y_blocks[:, ::2].astype(np.float32, copy=False)  # (nb, L)
    Y_odd = Y_blocks[:, 1::2].astype(np.float32, copy=False)  # (nb, L)
    C = C.astype(np.float32, copy=False)

    # Correlate with all codewords exactly once. Both hypotheses reuse these
    # products, which halves the matmul work compared with recomputing them.
    corr_even = Y_even @ C.T  # (nb, n_codewords)
    corr_odd = Y_odd @ C.T  # (nb, n_codewords)
    s1 = sqrtg * corr_even + corr_odd  # state-1 scores
    s2 = corr_even + sqrtg * corr_odd  # state-2 scores
    del corr_even, corr_odd  # release the intermediates before the reductions

    best_if_s1 = np.argmax(s1, axis=1)
    best_if_s2 = np.argmax(s2, axis=1)
    # Pick the hypothesis whose best per-block scores are larger overall.
    tot1 = np.sum(np.max(s1, axis=1))
    tot2 = np.sum(np.max(s2, axis=1))
    state = 1 if tot1 >= tot2 else 2
    indices = best_if_s1 if state == 1 else best_if_s2

    decoded = "".join(a + b for a, b in (index_to_pair(i) for i in indices))
    return decoded, state
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
    """Decode ``Y`` against codebook ``C`` and return only the decoded text.

    Convenience wrapper around ``decode_blocks_with_state`` that drops the
    estimated channel state from its ``(decoded, state)`` result.
    """
    text, _state = decode_blocks_with_state(Y, C)
    return text