feat: get 67% correctness
This commit is contained in:
parent
df631a2199
commit
0c1737647b
4 changed files with 157 additions and 76 deletions
60
decoder.py
60
decoder.py
|
@ -1,40 +1,32 @@
|
||||||
|
# decoder.py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import logaddexp
|
from encoder import ALPHABET, G
|
||||||
from utils import index_to_char
|
|
||||||
from codebook import construct_codebook
|
|
||||||
|
|
||||||
def decode_message(Y, codebook):
|
|
||||||
|
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||||
"""
|
"""
|
||||||
ML decoding for unknown channel state s:
|
Decode received samples by maximum correlation score
|
||||||
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
|
|
||||||
We use log-sum-exp to combine both branch metrics.
|
|
||||||
"""
|
"""
|
||||||
G = 10
|
n = C.shape[1]
|
||||||
Y1, Y2 = Y[::2], Y[1::2]
|
half = n // 2
|
||||||
best_idx = None
|
num = Y.size // n
|
||||||
best_metric = -np.inf
|
C1 = C[:, :half]
|
||||||
for i, c in enumerate(codebook):
|
C2 = C[:, half:]
|
||||||
# Only consider indices that map to characters
|
sqrtG = np.sqrt(G)
|
||||||
if i not in index_to_char:
|
recovered = []
|
||||||
continue
|
for k in range(num):
|
||||||
c1, c2 = c[::2], c[1::2]
|
Yb = Y[k*n:(k+1)*n]
|
||||||
# Branch metrics (up to additive constants)
|
Ye, Yo = Yb[0::2], Yb[1::2]
|
||||||
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
||||||
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
|
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
||||||
# Combine via log-sum-exp
|
score = np.maximum(s1, s2)
|
||||||
metric = logaddexp(s1, s2)
|
best = int(np.argmax(score))
|
||||||
if metric > best_metric:
|
recovered.append(ALPHABET[best])
|
||||||
best_metric = metric
|
return "".join(recovered)
|
||||||
best_idx = i
|
|
||||||
return best_idx
|
|
||||||
|
|
||||||
|
|
||||||
def signal_to_text(Y, codebook, r=6):
|
def count_errors(orig: str, est: str):
|
||||||
# Reconstruct codebook length (seg_len)
|
"""
|
||||||
_, seg_len, _, _ = construct_codebook(r, 1)
|
List mismatches between orig and est
|
||||||
text = ''
|
"""
|
||||||
for i in range(40):
|
return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e]
|
||||||
seg = Y[i * seg_len:(i + 1) * seg_len]
|
|
||||||
idx = decode_message(seg, codebook)
|
|
||||||
text += index_to_char.get(idx, '?')
|
|
||||||
return text
|
|
66
encoder.py
66
encoder.py
|
@ -1,14 +1,56 @@
|
||||||
# encoder.py (unchanged except default r)
|
# encoder.py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from codebook import construct_codebook
|
|
||||||
from utils import char_to_index, normalize_energy
|
|
||||||
|
|
||||||
def text_to_signal(text, r=5, Eb=3):
|
# System parameters
|
||||||
assert len(text) == 40, "Message must be exactly 40 characters."
|
G = 10.0 # power gain for even samples
|
||||||
codebook, n, m, alpha = construct_codebook(r, Eb)
|
ENERGY_LIMIT = 2000.0 # total energy per block
|
||||||
# Map each character to its codeword
|
ALPHABET = (
|
||||||
msg_indices = [char_to_index[c] for c in text]
|
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
||||||
signal = np.concatenate([codebook[i] for i in msg_indices])
|
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
||||||
# Enforce the energy constraint
|
[str(i) for i in range(10)] +
|
||||||
signal = normalize_energy(signal, energy_limit=2000)
|
[' ', '.']
|
||||||
return signal, codebook
|
)
|
||||||
|
assert len(ALPHABET) == 64, "Alphabet must be size 64"
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard(r: int) -> np.ndarray:
|
||||||
|
if r == 0:
|
||||||
|
return np.array([[1]], dtype=float)
|
||||||
|
M = hadamard(r-1)
|
||||||
|
return np.block([[M, M], [M, -M]])
|
||||||
|
|
||||||
|
|
||||||
|
def Br(r: int) -> np.ndarray:
|
||||||
|
M = hadamard(r)
|
||||||
|
return np.vstack([M, -M])
|
||||||
|
|
||||||
|
|
||||||
|
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Build 64x64 codebook and scale blocks so energy per block ≤ ENERGY_LIMIT/num_blocks
|
||||||
|
"""
|
||||||
|
B = Br(r)
|
||||||
|
C = np.hstack([B, B]).astype(float)
|
||||||
|
raw_norm = np.sum(C[0]**2)
|
||||||
|
margin = 0.95
|
||||||
|
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
|
||||||
|
return np.sqrt(alpha) * C
|
||||||
|
|
||||||
|
|
||||||
|
def interleave(c: np.ndarray) -> np.ndarray:
|
||||||
|
half = c.size // 2
|
||||||
|
x = np.empty(c.size)
|
||||||
|
x[0::2] = c[:half]
|
||||||
|
x[1::2] = c[half:]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Encode 40-character message into interleaved code symbols
|
||||||
|
"""
|
||||||
|
if len(msg) != 40:
|
||||||
|
raise ValueError("Message must be exactly 40 characters.")
|
||||||
|
idx = [ALPHABET.index(ch) for ch in msg]
|
||||||
|
blocks = [interleave(C[i]) for i in idx]
|
||||||
|
return np.concatenate(blocks)
|
|
@ -1,17 +1,36 @@
|
||||||
from encoder import text_to_signal
|
# test_local.py
|
||||||
from decoder import signal_to_text
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from encoder import make_codebook, encode_message
|
||||||
|
from decoder import decode_blocks, count_errors
|
||||||
from channel import channel
|
from channel import channel
|
||||||
|
|
||||||
def test_local():
|
|
||||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
|
||||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
|
||||||
y = channel(x)
|
|
||||||
decoded = signal_to_text(y, codebook, r=6)
|
|
||||||
|
|
||||||
print(f"Original: {message}")
|
def main():
|
||||||
print(f"Decoded : {decoded}")
|
parser = argparse.ArgumentParser(description="Local test using channel.py")
|
||||||
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
parser.add_argument("--message", required=True, help="40-character message")
|
||||||
print(f"Character errors: {errors}/40")
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
msg = args.message
|
||||||
|
if len(msg) != 40:
|
||||||
|
raise ValueError("Message must be exactly 40 characters.")
|
||||||
|
C = make_codebook(r=5, num_blocks=len(msg))
|
||||||
|
x = encode_message(msg, C)
|
||||||
|
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
|
||||||
|
|
||||||
|
Y = channel(x)
|
||||||
|
|
||||||
|
msg_hat = decode_blocks(Y, C)
|
||||||
|
print(f"↓ Decoded message: {msg_hat}")
|
||||||
|
|
||||||
|
errors = count_errors(msg, msg_hat)
|
||||||
|
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
||||||
|
if errors:
|
||||||
|
for i, o, e in errors:
|
||||||
|
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
||||||
|
else:
|
||||||
|
print("✔️ No decoding errors!")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_local()
|
main()
|
||||||
|
|
|
@ -1,28 +1,56 @@
|
||||||
|
# test_server.py
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from encoder import text_to_signal
|
from encoder import make_codebook, encode_message
|
||||||
from decoder import signal_to_text
|
from decoder import decode_blocks, count_errors
|
||||||
|
|
||||||
def test_server():
|
|
||||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
|
||||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
|
||||||
np.savetxt("input.txt", x, fmt="%.10f")
|
|
||||||
|
|
||||||
|
def call_client(input_path, output_path, host, port):
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
"python3", "client.py",
|
"python3", "client.py",
|
||||||
"--input_file", "input.txt",
|
f"--input_file={input_path}",
|
||||||
"--output_file", "output.txt",
|
f"--output_file={output_path}",
|
||||||
"--srv_hostname", "iscsrv72.epfl.ch",
|
f"--srv_hostname={host}",
|
||||||
"--srv_port", "80"
|
f"--srv_port={port}"
|
||||||
])
|
], check=True)
|
||||||
|
|
||||||
y = np.loadtxt("output.txt")
|
|
||||||
decoded = signal_to_text(y, codebook, r=6)
|
|
||||||
|
|
||||||
print(f"Original: {message}")
|
def main():
|
||||||
print(f"Decoded : {decoded}")
|
parser = argparse.ArgumentParser(description="Server test using client.py")
|
||||||
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
parser.add_argument("--message", required=True, help="40-character message to send")
|
||||||
print(f"Character errors: {errors}/40")
|
parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
|
||||||
|
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
msg = args.message
|
||||||
|
if len(msg) != 40:
|
||||||
|
raise ValueError("Message must be exactly 40 characters.")
|
||||||
|
C = make_codebook(r=5, num_blocks=len(msg))
|
||||||
|
x = encode_message(msg, C)
|
||||||
|
|
||||||
|
# write encoded symbols to fixed input.txt
|
||||||
|
input_file = "input.txt"
|
||||||
|
output_file = "output.txt"
|
||||||
|
np.savetxt(input_file, x)
|
||||||
|
|
||||||
|
# run client.py to read input.txt and write output.txt
|
||||||
|
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
|
||||||
|
|
||||||
|
# read received samples
|
||||||
|
Y = np.loadtxt(output_file)
|
||||||
|
|
||||||
|
msg_hat = decode_blocks(Y, C)
|
||||||
|
print(f"↓ Decoded message: {msg_hat}")
|
||||||
|
|
||||||
|
errors = count_errors(msg, msg_hat)
|
||||||
|
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
||||||
|
if errors:
|
||||||
|
for i, o, e in errors:
|
||||||
|
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
||||||
|
else:
|
||||||
|
print("✔️ No decoding errors!")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_server()
|
main()
|
Loading…
Add table
Reference in a new issue