feat: get 67% correctness

This commit is contained in:
appellet 2025-05-25 16:35:41 +02:00
parent df631a2199
commit 0c1737647b
4 changed files with 157 additions and 76 deletions

View file

@ -1,40 +1,32 @@
# decoder.py
import numpy as np
from numpy import logaddexp
from utils import index_to_char
from codebook import construct_codebook
from encoder import ALPHABET, G
def decode_message(Y, codebook):
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
"""
ML decoding for unknown channel state s:
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
We use log-sum-exp to combine both branch metrics.
Decode received samples by maximum correlation score
"""
G = 10
Y1, Y2 = Y[::2], Y[1::2]
best_idx = None
best_metric = -np.inf
for i, c in enumerate(codebook):
# Only consider indices that map to characters
if i not in index_to_char:
continue
c1, c2 = c[::2], c[1::2]
# Branch metrics (up to additive constants)
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
# Combine via log-sum-exp
metric = logaddexp(s1, s2)
if metric > best_metric:
best_metric = metric
best_idx = i
return best_idx
n = C.shape[1]
half = n // 2
num = Y.size // n
C1 = C[:, :half]
C2 = C[:, half:]
sqrtG = np.sqrt(G)
recovered = []
for k in range(num):
Yb = Y[k*n:(k+1)*n]
Ye, Yo = Yb[0::2], Yb[1::2]
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
score = np.maximum(s1, s2)
best = int(np.argmax(score))
recovered.append(ALPHABET[best])
return "".join(recovered)
def signal_to_text(Y, codebook, r=6):
# Reconstruct codebook length (seg_len)
_, seg_len, _, _ = construct_codebook(r, 1)
text = ''
for i in range(40):
seg = Y[i * seg_len:(i + 1) * seg_len]
idx = decode_message(seg, codebook)
text += index_to_char.get(idx, '?')
return text
def count_errors(orig: str, est: str):
"""
List mismatches between orig and est
"""
return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e]

View file

@ -1,14 +1,56 @@
# encoder.py (unchanged except default r)
# encoder.py
import numpy as np
from codebook import construct_codebook
from utils import char_to_index, normalize_energy
def text_to_signal(text, r=5, Eb=3):
assert len(text) == 40, "Message must be exactly 40 characters."
codebook, n, m, alpha = construct_codebook(r, Eb)
# Map each character to its codeword
msg_indices = [char_to_index[c] for c in text]
signal = np.concatenate([codebook[i] for i in msg_indices])
# Enforce the energy constraint
signal = normalize_energy(signal, energy_limit=2000)
return signal, codebook
# System parameters
G = 10.0 # power gain for even samples
ENERGY_LIMIT = 2000.0 # total energy per block
ALPHABET = (
[chr(i) for i in range(ord('a'), ord('z')+1)] +
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
[str(i) for i in range(10)] +
[' ', '.']
)
assert len(ALPHABET) == 64, "Alphabet must be size 64"
def hadamard(r: int) -> np.ndarray:
if r == 0:
return np.array([[1]], dtype=float)
M = hadamard(r-1)
return np.block([[M, M], [M, -M]])
def Br(r: int) -> np.ndarray:
M = hadamard(r)
return np.vstack([M, -M])
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
"""
Build 64x64 codebook and scale blocks so energy per block ENERGY_LIMIT/num_blocks
"""
B = Br(r)
C = np.hstack([B, B]).astype(float)
raw_norm = np.sum(C[0]**2)
margin = 0.95
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
return np.sqrt(alpha) * C
def interleave(c: np.ndarray) -> np.ndarray:
half = c.size // 2
x = np.empty(c.size)
x[0::2] = c[:half]
x[1::2] = c[half:]
return x
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
"""
Encode 40-character message into interleaved code symbols
"""
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
idx = [ALPHABET.index(ch) for ch in msg]
blocks = [interleave(C[i]) for i in idx]
return np.concatenate(blocks)

View file

@ -1,17 +1,36 @@
from encoder import text_to_signal
from decoder import signal_to_text
# test_local.py
#!/usr/bin/env python3
import argparse
import numpy as np
from encoder import make_codebook, encode_message
from decoder import decode_blocks, count_errors
from channel import channel
def test_local():
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=3)
y = channel(x)
decoded = signal_to_text(y, codebook, r=6)
print(f"Original: {message}")
print(f"Decoded : {decoded}")
errors = sum(1 for a, b in zip(message, decoded) if a != b)
print(f"Character errors: {errors}/40")
def main():
parser = argparse.ArgumentParser(description="Local test using channel.py")
parser.add_argument("--message", required=True, help="40-character message")
args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
Y = channel(x)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__":
test_local()
main()

View file

@ -1,28 +1,56 @@
# test_server.py
#!/usr/bin/env python3
import argparse
import subprocess
import numpy as np
from encoder import text_to_signal
from decoder import signal_to_text
from encoder import make_codebook, encode_message
from decoder import decode_blocks, count_errors
def test_server():
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=3)
np.savetxt("input.txt", x, fmt="%.10f")
def call_client(input_path, output_path, host, port):
subprocess.run([
"python3", "client.py",
"--input_file", "input.txt",
"--output_file", "output.txt",
"--srv_hostname", "iscsrv72.epfl.ch",
"--srv_port", "80"
])
f"--input_file={input_path}",
f"--output_file={output_path}",
f"--srv_hostname={host}",
f"--srv_port={port}"
], check=True)
y = np.loadtxt("output.txt")
decoded = signal_to_text(y, codebook, r=6)
print(f"Original: {message}")
print(f"Decoded : {decoded}")
errors = sum(1 for a, b in zip(message, decoded) if a != b)
print(f"Character errors: {errors}/40")
def main():
parser = argparse.ArgumentParser(description="Server test using client.py")
parser.add_argument("--message", required=True, help="40-character message to send")
parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
# write encoded symbols to fixed input.txt
input_file = "input.txt"
output_file = "output.txt"
np.savetxt(input_file, x)
# run client.py to read input.txt and write output.txt
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
# read received samples
Y = np.loadtxt(output_file)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__":
test_server()
main()