From 545d4d7768ffc2320c1f4537024219cd2c668a95 Mon Sep 17 00:00:00 2001 From: appellet Date: Sun, 25 May 2025 14:41:39 +0200 Subject: [PATCH] feat: improve results basing on theory --- decoder.py | 40 ++++++++++++++++---------- encoder.py | 5 +++- evaluate_zero_error_ratio.py | 54 ++++++++++++++++++++++++++++++++++++ test_local.py | 8 +++--- test_server.py | 2 +- 5 files changed, 89 insertions(+), 20 deletions(-) create mode 100644 evaluate_zero_error_ratio.py diff --git a/decoder.py b/decoder.py index 409c6fe..3a82c04 100644 --- a/decoder.py +++ b/decoder.py @@ -1,28 +1,40 @@ - -# decoder.py import numpy as np +from numpy import logaddexp from utils import index_to_char from codebook import construct_codebook def decode_message(Y, codebook): + """ + ML decoding for unknown channel state s: + p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2) + We use log-sum-exp to combine both branch metrics. + """ + G = 10 Y1, Y2 = Y[::2], Y[1::2] - scores = [] + best_idx = None + best_metric = -np.inf for i, c in enumerate(codebook): + # Only consider indices that map to characters + if i not in index_to_char: + continue c1, c2 = c[::2], c[1::2] - s1 = np.sqrt(10) * np.dot(Y1, c1) + np.dot(Y2, c2) - s2 = np.dot(Y1, c1) + np.sqrt(10) * np.dot(Y2, c2) - scores.append(max(s1, s2)) - return np.argmax(scores) + # Branch metrics (up to additive constants) + s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2) + s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2) + # Combine via log-sum-exp + metric = logaddexp(s1, s2) + if metric > best_metric: + best_metric = metric + best_idx = i + return best_idx + def signal_to_text(Y, codebook, r=6): - _, n, _, _ = construct_codebook(r, 1) - seg_len = n + # Reconstruct codebook length (seg_len) + _, seg_len, _, _ = construct_codebook(r, 1) text = '' for i in range(40): seg = Y[i * seg_len:(i + 1) * seg_len] - index = decode_message(seg, codebook) - if index in index_to_char: - text += index_to_char[index] - else: - text += '?' # Unknown character + idx = decode_message(seg, codebook) + text += index_to_char.get(idx, '?') return text \ No newline at end of file diff --git a/encoder.py b/encoder.py index 3e1674f..92e81a8 100644 --- a/encoder.py +++ b/encoder.py @@ -1,11 +1,14 @@ +# encoder.py (unchanged except default r) import numpy as np from codebook import construct_codebook from utils import char_to_index, normalize_energy -def text_to_signal(text, r=9, Eb=4): +def text_to_signal(text, r=5, Eb=3): assert len(text) == 40, "Message must be exactly 40 characters." codebook, n, m, alpha = construct_codebook(r, Eb) + # Map each character to its codeword msg_indices = [char_to_index[c] for c in text] signal = np.concatenate([codebook[i] for i in msg_indices]) + # Enforce the energy constraint signal = normalize_energy(signal, energy_limit=2000) return signal, codebook \ No newline at end of file diff --git a/evaluate_zero_error_ratio.py b/evaluate_zero_error_ratio.py new file mode 100644 index 0000000..986c74f --- /dev/null +++ b/evaluate_zero_error_ratio.py @@ -0,0 +1,54 @@ +# evaluate_zero_error_ratio.py + +import argparse +import numpy as np +from encoder import text_to_signal +from decoder import signal_to_text +from channel import channel + +def zero_error_ratio(message: str, r: int, Eb: float, n_runs: int) -> float: + """ + Runs the encode–channel–decode pipeline n_runs times on the same message, + and returns the fraction of runs with 0 character errors. + """ + zero_count = 0 + for _ in range(n_runs): + # encode + x, codebook = text_to_signal(message, r=r, Eb=Eb) + # pass through channel + y = channel(x) + # decode + decoded = signal_to_text(y, codebook, r=r) + # count errors + errors = sum(1 for a, b in zip(message, decoded) if a != b) + if errors == 0: + zero_count += 1 + return zero_count / n_runs + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Estimate the zero-error run ratio over n trials." + ) + parser.add_argument( + "--message", type=str, required=True, + help="The 40-char message to test." + ) + parser.add_argument( + "--r", type=int, default=6, + help="Codebook parameter r (default: 6)." + ) + parser.add_argument( + "--Eb", type=float, default=3.0, + help="Energy per bit Eb (default: 3)." + ) + parser.add_argument( + "--n", type=int, default=1000, + help="Number of runs (default: 1000)." + ) + args = parser.parse_args() + + if len(args.message) != 40: + raise ValueError("Message must be exactly 40 characters long.") + + ratio = zero_error_ratio(args.message, args.r, args.Eb, args.n) + print(f"Zero-error runs: {ratio:.3%} ({ratio:.4f} of {args.n})") diff --git a/test_local.py b/test_local.py index 98da1f2..23a6901 100644 --- a/test_local.py +++ b/test_local.py @@ -3,10 +3,10 @@ from decoder import signal_to_text from channel import channel def test_local(): - message = "This is the end of the PDC course. Good " - x, codebook = text_to_signal(message, r=9, Eb=5) + message = "HelloWorld123 ThisIsATestMessage12345678" + x, codebook = text_to_signal(message, r=6, Eb=3) y = channel(x) - decoded = signal_to_text(y, codebook, r=9) + decoded = signal_to_text(y, codebook, r=6) print(f"Original: {message}") print(f"Decoded : {decoded}") @@ -14,4 +14,4 @@ def test_local(): print(f"Character errors: {errors}/40") if __name__ == "__main__": - test_local() \ No newline at end of file + test_local() diff --git a/test_server.py b/test_server.py index 60b8212..bb4190f 100644 --- a/test_server.py +++ b/test_server.py @@ -5,7 +5,7 @@ from decoder import signal_to_text def test_server(): message = "HelloWorld123 ThisIsATestMessage12345678" - x, codebook = text_to_signal(message, r=6, Eb=1) + x, codebook = text_to_signal(message, r=6, Eb=3) np.savetxt("input.txt", x, fmt="%.10f") subprocess.run([