feat: improve results basing on theory

This commit is contained in:
appellet 2025-05-25 14:41:39 +02:00
parent 0c19f7ce40
commit 545d4d7768
5 changed files with 89 additions and 20 deletions

View file

@ -1,28 +1,40 @@
# decoder.py
import numpy as np
from numpy import logaddexp
from utils import index_to_char
from codebook import construct_codebook
def decode_message(Y, codebook):
"""
ML decoding for unknown channel state s:
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
We use log-sum-exp to combine both branch metrics.
"""
G = 10
Y1, Y2 = Y[::2], Y[1::2]
scores = []
best_idx = None
best_metric = -np.inf
for i, c in enumerate(codebook):
# Only consider indices that map to characters
if i not in index_to_char:
continue
c1, c2 = c[::2], c[1::2]
s1 = np.sqrt(10) * np.dot(Y1, c1) + np.dot(Y2, c2)
s2 = np.dot(Y1, c1) + np.sqrt(10) * np.dot(Y2, c2)
scores.append(max(s1, s2))
return np.argmax(scores)
# Branch metrics (up to additive constants)
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
# Combine via log-sum-exp
metric = logaddexp(s1, s2)
if metric > best_metric:
best_metric = metric
best_idx = i
return best_idx
def signal_to_text(Y, codebook, r=6):
_, n, _, _ = construct_codebook(r, 1)
seg_len = n
# Reconstruct codebook length (seg_len)
_, seg_len, _, _ = construct_codebook(r, 1)
text = ''
for i in range(40):
seg = Y[i * seg_len:(i + 1) * seg_len]
index = decode_message(seg, codebook)
if index in index_to_char:
text += index_to_char[index]
else:
text += '?' # Unknown character
idx = decode_message(seg, codebook)
text += index_to_char.get(idx, '?')
return text

View file

@ -1,11 +1,14 @@
# encoder.py (unchanged except default r)
import numpy as np
from codebook import construct_codebook
from utils import char_to_index, normalize_energy
def text_to_signal(text, r=9, Eb=4):
def text_to_signal(text, r=5, Eb=3):
assert len(text) == 40, "Message must be exactly 40 characters."
codebook, n, m, alpha = construct_codebook(r, Eb)
# Map each character to its codeword
msg_indices = [char_to_index[c] for c in text]
signal = np.concatenate([codebook[i] for i in msg_indices])
# Enforce the energy constraint
signal = normalize_energy(signal, energy_limit=2000)
return signal, codebook

View file

@ -0,0 +1,54 @@
# evaluate_zero_error_ratio.py
import argparse
import numpy as np
from encoder import text_to_signal
from decoder import signal_to_text
from channel import channel
def zero_error_ratio(message: str, r: int, Eb: float, n_runs: int) -> float:
"""
Runs the encodechanneldecode pipeline n_runs times on the same message,
and returns the fraction of runs with 0 character errors.
"""
zero_count = 0
for _ in range(n_runs):
# encode
x, codebook = text_to_signal(message, r=r, Eb=Eb)
# pass through channel
y = channel(x)
# decode
decoded = signal_to_text(y, codebook, r=r)
# count errors
errors = sum(1 for a, b in zip(message, decoded) if a != b)
if errors == 0:
zero_count += 1
return zero_count / n_runs
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Estimate the zero-error run ratio over n trials."
)
parser.add_argument(
"--message", type=str, required=True,
help="The 40-char message to test."
)
parser.add_argument(
"--r", type=int, default=6,
help="Codebook parameter r (default: 6)."
)
parser.add_argument(
"--Eb", type=float, default=3.0,
help="Energy per bit Eb (default: 3)."
)
parser.add_argument(
"--n", type=int, default=1000,
help="Number of runs (default: 1000)."
)
args = parser.parse_args()
if len(args.message) != 40:
raise ValueError("Message must be exactly 40 characters long.")
ratio = zero_error_ratio(args.message, args.r, args.Eb, args.n)
print(f"Zero-error runs: {ratio:.3%} ({ratio:.4f} of {args.n})")

View file

@ -3,10 +3,10 @@ from decoder import signal_to_text
from channel import channel
def test_local():
message = "This is the end of the PDC course. Good "
x, codebook = text_to_signal(message, r=9, Eb=5)
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=3)
y = channel(x)
decoded = signal_to_text(y, codebook, r=9)
decoded = signal_to_text(y, codebook, r=6)
print(f"Original: {message}")
print(f"Decoded : {decoded}")
@ -14,4 +14,4 @@ def test_local():
print(f"Character errors: {errors}/40")
if __name__ == "__main__":
test_local()
test_local()

View file

@ -5,7 +5,7 @@ from decoder import signal_to_text
def test_server():
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=1)
x, codebook = text_to_signal(message, r=6, Eb=3)
np.savetxt("input.txt", x, fmt="%.10f")
subprocess.run([