feat: improve results basing on theory
This commit is contained in:
parent
0c19f7ce40
commit
545d4d7768
5 changed files with 89 additions and 20 deletions
40
decoder.py
40
decoder.py
|
@ -1,28 +1,40 @@
|
||||||
|
|
||||||
# decoder.py
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy import logaddexp
|
||||||
from utils import index_to_char
|
from utils import index_to_char
|
||||||
from codebook import construct_codebook
|
from codebook import construct_codebook
|
||||||
|
|
||||||
def decode_message(Y, codebook):
|
def decode_message(Y, codebook):
|
||||||
|
"""
|
||||||
|
ML decoding for unknown channel state s:
|
||||||
|
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
|
||||||
|
We use log-sum-exp to combine both branch metrics.
|
||||||
|
"""
|
||||||
|
G = 10
|
||||||
Y1, Y2 = Y[::2], Y[1::2]
|
Y1, Y2 = Y[::2], Y[1::2]
|
||||||
scores = []
|
best_idx = None
|
||||||
|
best_metric = -np.inf
|
||||||
for i, c in enumerate(codebook):
|
for i, c in enumerate(codebook):
|
||||||
|
# Only consider indices that map to characters
|
||||||
|
if i not in index_to_char:
|
||||||
|
continue
|
||||||
c1, c2 = c[::2], c[1::2]
|
c1, c2 = c[::2], c[1::2]
|
||||||
s1 = np.sqrt(10) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
# Branch metrics (up to additive constants)
|
||||||
s2 = np.dot(Y1, c1) + np.sqrt(10) * np.dot(Y2, c2)
|
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
||||||
scores.append(max(s1, s2))
|
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
|
||||||
return np.argmax(scores)
|
# Combine via log-sum-exp
|
||||||
|
metric = logaddexp(s1, s2)
|
||||||
|
if metric > best_metric:
|
||||||
|
best_metric = metric
|
||||||
|
best_idx = i
|
||||||
|
return best_idx
|
||||||
|
|
||||||
|
|
||||||
def signal_to_text(Y, codebook, r=6):
|
def signal_to_text(Y, codebook, r=6):
|
||||||
_, n, _, _ = construct_codebook(r, 1)
|
# Reconstruct codebook length (seg_len)
|
||||||
seg_len = n
|
_, seg_len, _, _ = construct_codebook(r, 1)
|
||||||
text = ''
|
text = ''
|
||||||
for i in range(40):
|
for i in range(40):
|
||||||
seg = Y[i * seg_len:(i + 1) * seg_len]
|
seg = Y[i * seg_len:(i + 1) * seg_len]
|
||||||
index = decode_message(seg, codebook)
|
idx = decode_message(seg, codebook)
|
||||||
if index in index_to_char:
|
text += index_to_char.get(idx, '?')
|
||||||
text += index_to_char[index]
|
|
||||||
else:
|
|
||||||
text += '?' # Unknown character
|
|
||||||
return text
|
return text
|
|
@ -1,11 +1,14 @@
|
||||||
|
# encoder.py (unchanged except default r)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from codebook import construct_codebook
|
from codebook import construct_codebook
|
||||||
from utils import char_to_index, normalize_energy
|
from utils import char_to_index, normalize_energy
|
||||||
|
|
||||||
def text_to_signal(text, r=9, Eb=4):
|
def text_to_signal(text, r=5, Eb=3):
|
||||||
assert len(text) == 40, "Message must be exactly 40 characters."
|
assert len(text) == 40, "Message must be exactly 40 characters."
|
||||||
codebook, n, m, alpha = construct_codebook(r, Eb)
|
codebook, n, m, alpha = construct_codebook(r, Eb)
|
||||||
|
# Map each character to its codeword
|
||||||
msg_indices = [char_to_index[c] for c in text]
|
msg_indices = [char_to_index[c] for c in text]
|
||||||
signal = np.concatenate([codebook[i] for i in msg_indices])
|
signal = np.concatenate([codebook[i] for i in msg_indices])
|
||||||
|
# Enforce the energy constraint
|
||||||
signal = normalize_energy(signal, energy_limit=2000)
|
signal = normalize_energy(signal, energy_limit=2000)
|
||||||
return signal, codebook
|
return signal, codebook
|
54
evaluate_zero_error_ratio.py
Normal file
54
evaluate_zero_error_ratio.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# evaluate_zero_error_ratio.py
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from encoder import text_to_signal
|
||||||
|
from decoder import signal_to_text
|
||||||
|
from channel import channel
|
||||||
|
|
||||||
|
def zero_error_ratio(message: str, r: int, Eb: float, n_runs: int) -> float:
|
||||||
|
"""
|
||||||
|
Runs the encode–channel–decode pipeline n_runs times on the same message,
|
||||||
|
and returns the fraction of runs with 0 character errors.
|
||||||
|
"""
|
||||||
|
zero_count = 0
|
||||||
|
for _ in range(n_runs):
|
||||||
|
# encode
|
||||||
|
x, codebook = text_to_signal(message, r=r, Eb=Eb)
|
||||||
|
# pass through channel
|
||||||
|
y = channel(x)
|
||||||
|
# decode
|
||||||
|
decoded = signal_to_text(y, codebook, r=r)
|
||||||
|
# count errors
|
||||||
|
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
||||||
|
if errors == 0:
|
||||||
|
zero_count += 1
|
||||||
|
return zero_count / n_runs
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Estimate the zero-error run ratio over n trials."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--message", type=str, required=True,
|
||||||
|
help="The 40-char message to test."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--r", type=int, default=6,
|
||||||
|
help="Codebook parameter r (default: 6)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--Eb", type=float, default=3.0,
|
||||||
|
help="Energy per bit Eb (default: 3)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n", type=int, default=1000,
|
||||||
|
help="Number of runs (default: 1000)."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if len(args.message) != 40:
|
||||||
|
raise ValueError("Message must be exactly 40 characters long.")
|
||||||
|
|
||||||
|
ratio = zero_error_ratio(args.message, args.r, args.Eb, args.n)
|
||||||
|
print(f"Zero-error runs: {ratio:.3%} ({ratio:.4f} of {args.n})")
|
|
@ -3,10 +3,10 @@ from decoder import signal_to_text
|
||||||
from channel import channel
|
from channel import channel
|
||||||
|
|
||||||
def test_local():
|
def test_local():
|
||||||
message = "This is the end of the PDC course. Good "
|
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||||
x, codebook = text_to_signal(message, r=9, Eb=5)
|
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||||
y = channel(x)
|
y = channel(x)
|
||||||
decoded = signal_to_text(y, codebook, r=9)
|
decoded = signal_to_text(y, codebook, r=6)
|
||||||
|
|
||||||
print(f"Original: {message}")
|
print(f"Original: {message}")
|
||||||
print(f"Decoded : {decoded}")
|
print(f"Decoded : {decoded}")
|
||||||
|
@ -14,4 +14,4 @@ def test_local():
|
||||||
print(f"Character errors: {errors}/40")
|
print(f"Character errors: {errors}/40")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_local()
|
test_local()
|
||||||
|
|
|
@ -5,7 +5,7 @@ from decoder import signal_to_text
|
||||||
|
|
||||||
def test_server():
|
def test_server():
|
||||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||||
x, codebook = text_to_signal(message, r=6, Eb=1)
|
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||||
np.savetxt("input.txt", x, fmt="%.10f")
|
np.savetxt("input.txt", x, fmt="%.10f")
|
||||||
|
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
|
|
Loading…
Add table
Reference in a new issue