feat: improve results basing on theory
This commit is contained in:
parent
0c19f7ce40
commit
545d4d7768
5 changed files with 89 additions and 20 deletions
40
decoder.py
40
decoder.py
|
@ -1,28 +1,40 @@
|
|||
|
||||
# decoder.py
|
||||
import numpy as np
|
||||
from numpy import logaddexp
|
||||
from utils import index_to_char
|
||||
from codebook import construct_codebook
|
||||
|
||||
def decode_message(Y, codebook):
|
||||
"""
|
||||
ML decoding for unknown channel state s:
|
||||
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
|
||||
We use log-sum-exp to combine both branch metrics.
|
||||
"""
|
||||
G = 10
|
||||
Y1, Y2 = Y[::2], Y[1::2]
|
||||
scores = []
|
||||
best_idx = None
|
||||
best_metric = -np.inf
|
||||
for i, c in enumerate(codebook):
|
||||
# Only consider indices that map to characters
|
||||
if i not in index_to_char:
|
||||
continue
|
||||
c1, c2 = c[::2], c[1::2]
|
||||
s1 = np.sqrt(10) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
||||
s2 = np.dot(Y1, c1) + np.sqrt(10) * np.dot(Y2, c2)
|
||||
scores.append(max(s1, s2))
|
||||
return np.argmax(scores)
|
||||
# Branch metrics (up to additive constants)
|
||||
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
||||
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
|
||||
# Combine via log-sum-exp
|
||||
metric = logaddexp(s1, s2)
|
||||
if metric > best_metric:
|
||||
best_metric = metric
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
|
||||
def signal_to_text(Y, codebook, r=6):
|
||||
_, n, _, _ = construct_codebook(r, 1)
|
||||
seg_len = n
|
||||
# Reconstruct codebook length (seg_len)
|
||||
_, seg_len, _, _ = construct_codebook(r, 1)
|
||||
text = ''
|
||||
for i in range(40):
|
||||
seg = Y[i * seg_len:(i + 1) * seg_len]
|
||||
index = decode_message(seg, codebook)
|
||||
if index in index_to_char:
|
||||
text += index_to_char[index]
|
||||
else:
|
||||
text += '?' # Unknown character
|
||||
idx = decode_message(seg, codebook)
|
||||
text += index_to_char.get(idx, '?')
|
||||
return text
|
|
@ -1,11 +1,14 @@
|
|||
# encoder.py (unchanged except default r)
|
||||
import numpy as np
|
||||
from codebook import construct_codebook
|
||||
from utils import char_to_index, normalize_energy
|
||||
|
||||
def text_to_signal(text, r=9, Eb=4):
|
||||
def text_to_signal(text, r=5, Eb=3):
|
||||
assert len(text) == 40, "Message must be exactly 40 characters."
|
||||
codebook, n, m, alpha = construct_codebook(r, Eb)
|
||||
# Map each character to its codeword
|
||||
msg_indices = [char_to_index[c] for c in text]
|
||||
signal = np.concatenate([codebook[i] for i in msg_indices])
|
||||
# Enforce the energy constraint
|
||||
signal = normalize_energy(signal, energy_limit=2000)
|
||||
return signal, codebook
|
54
evaluate_zero_error_ratio.py
Normal file
54
evaluate_zero_error_ratio.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# evaluate_zero_error_ratio.py
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from encoder import text_to_signal
|
||||
from decoder import signal_to_text
|
||||
from channel import channel
|
||||
|
||||
def zero_error_ratio(message: str, r: int, Eb: float, n_runs: int) -> float:
|
||||
"""
|
||||
Runs the encode–channel–decode pipeline n_runs times on the same message,
|
||||
and returns the fraction of runs with 0 character errors.
|
||||
"""
|
||||
zero_count = 0
|
||||
for _ in range(n_runs):
|
||||
# encode
|
||||
x, codebook = text_to_signal(message, r=r, Eb=Eb)
|
||||
# pass through channel
|
||||
y = channel(x)
|
||||
# decode
|
||||
decoded = signal_to_text(y, codebook, r=r)
|
||||
# count errors
|
||||
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
||||
if errors == 0:
|
||||
zero_count += 1
|
||||
return zero_count / n_runs
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Estimate the zero-error run ratio over n trials."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--message", type=str, required=True,
|
||||
help="The 40-char message to test."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--r", type=int, default=6,
|
||||
help="Codebook parameter r (default: 6)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--Eb", type=float, default=3.0,
|
||||
help="Energy per bit Eb (default: 3)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1000,
|
||||
help="Number of runs (default: 1000)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.message) != 40:
|
||||
raise ValueError("Message must be exactly 40 characters long.")
|
||||
|
||||
ratio = zero_error_ratio(args.message, args.r, args.Eb, args.n)
|
||||
print(f"Zero-error runs: {ratio:.3%} ({ratio:.4f} of {args.n})")
|
|
@ -3,10 +3,10 @@ from decoder import signal_to_text
|
|||
from channel import channel
|
||||
|
||||
def test_local():
|
||||
message = "This is the end of the PDC course. Good "
|
||||
x, codebook = text_to_signal(message, r=9, Eb=5)
|
||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||
y = channel(x)
|
||||
decoded = signal_to_text(y, codebook, r=9)
|
||||
decoded = signal_to_text(y, codebook, r=6)
|
||||
|
||||
print(f"Original: {message}")
|
||||
print(f"Decoded : {decoded}")
|
||||
|
|
|
@ -5,7 +5,7 @@ from decoder import signal_to_text
|
|||
|
||||
def test_server():
|
||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||
x, codebook = text_to_signal(message, r=6, Eb=1)
|
||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||
np.savetxt("input.txt", x, fmt="%.10f")
|
||||
|
||||
subprocess.run([
|
||||
|
|
Loading…
Add table
Reference in a new issue