feat: improve decoder.py

This commit is contained in:
appellet 2025-05-27 12:04:14 +02:00
parent 177bc566cf
commit 8b876c0d06
2 changed files with 83 additions and 77 deletions

View file

@ -5,7 +5,7 @@ from encoder import ALPHABET, G
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
""" """
Decode received samples by maximum correlation score Decode received samples by maximum correlation score (state unknown, per-block).
""" """
n = C.shape[1] n = C.shape[1]
half = n // 2 half = n // 2
@ -15,7 +15,7 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
sqrtG = np.sqrt(G) sqrtG = np.sqrt(G)
recovered = [] recovered = []
for k in range(num): for k in range(num):
Yb = Y[k*n:(k+1)*n] Yb = Y[k * n:(k + 1) * n]
Ye, Yo = Yb[0::2], Yb[1::2] Ye, Yo = Yb[0::2], Yb[1::2]
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T) s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T) s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
@ -25,8 +25,37 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
return "".join(recovered) return "".join(recovered)
def count_errors(orig: str, est: str): def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
""" """
List mismatches between orig and est 1) Estimate the single channel state s{1,2} by comparing total energy
on even vs odd positions across the entire Y.
2) Decode **all** blocks using that one states scoring rule.
Returns (decoded_string, estimated_state).
""" """
return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e] n = C.shape[1]
half = n // 2
num = Y.size // n
C1 = C[:, :half]
C2 = C[:, half:]
sqrtG = np.sqrt(G)
# 1) state estimate from full-length energies
Ye_all = Y[0::2]
Yo_all = Y[1::2]
E_even = np.sum(Ye_all ** 2)
E_odd = np.sum(Yo_all ** 2)
s_est = 1 if E_even > E_odd else 2
recovered = []
for k in range(num):
Yb = Y[k * n:(k + 1) * n]
Ye, Yo = Yb[0::2], Yb[1::2]
if s_est == 1:
score = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
else:
score = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
best = int(np.argmax(score))
recovered.append(ALPHABET[best])
return "".join(recovered), s_est

121
main.py
View file

@ -16,127 +16,104 @@ INPUT_FILE = None
OUTPUT_FILE = None OUTPUT_FILE = None
def transmit(msg, C): def transmit(msg, C):
'''
Transmitter: encodes the message into real-valued samples using the codebook C.
'''
return encoder.encode_message(msg, C) return encoder.encode_message(msg, C)
def receive_local(c): def receive_local(c):
'''
Sends the samples through the local channel simulation.
'''
return channel.channel(c) return channel.channel(c)
def receive_server(c, hostname, port): def receive_server(c, hostname, port):
'''
Sends the samples to the remote server via client.py. If INPUT_FILE and/or
OUTPUT_FILE are set, uses those filenames (and preserves them); otherwise uses temporary files.
'''
global INPUT_FILE, OUTPUT_FILE global INPUT_FILE, OUTPUT_FILE
# Determine input file path # write or temp-file for input
if INPUT_FILE: if INPUT_FILE:
in_name = INPUT_FILE in_name, delete_in = INPUT_FILE, False
np.savetxt(in_name, c) np.savetxt(in_name, c)
delete_in = False
else: else:
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f: tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
np.savetxt(in_f.name, c) np.savetxt(tf.name, c)
in_name = in_f.name in_name, delete_in = tf.name, True
delete_in = True tf.close()
# Determine output file path # write or temp-file for output
if OUTPUT_FILE: if OUTPUT_FILE:
out_name = OUTPUT_FILE out_name, delete_out = OUTPUT_FILE, False
delete_out = False
else: else:
out_fd, out_name = tempfile.mkstemp(suffix='.txt') fd, out_name = tempfile.mkstemp(suffix='.txt')
os.close(out_fd) os.close(fd)
delete_out = True delete_out = True
# Invoke client.py
# call client.py
cmd = [ cmd = [
sys.executable, sys.executable,
str(pathlib.Path(__file__).parent / 'client.py'), str(pathlib.Path(__file__).parent / 'client.py'),
'--input_file', in_name, '--input_file', in_name,
'--output_file', out_name, '--output_file', out_name,
'--srv_hostname', hostname, '--srv_hostname', hostname,
'--srv_port', str(port) '--srv_port', str(port)
] ]
try: try:
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
Y = np.loadtxt(out_name) Y = np.loadtxt(out_name)
finally: finally:
if delete_in and os.path.exists(in_name): if delete_in and os.path.exists(in_name): os.remove(in_name)
os.remove(in_name) if delete_out and os.path.exists(out_name): os.remove(out_name)
if delete_out and os.path.exists(out_name):
os.remove(out_name)
return Y return Y
def receive(c, mode, hostname, port): def receive(c, mode, hostname, port):
'''
Wrapper to choose local or server channel.
'''
if mode == 'local': if mode == 'local':
return receive_local(c) return receive_local(c)
elif mode == 'server': elif mode == 'server':
return receive_server(c, hostname, port) return receive_server(c, hostname, port)
else: else:
raise ValueError("Mode must be 'local' or 'server'") raise ValueError("mode must be 'local' or 'server'")
def test_performance(msg, num_trials, mode, hostname, port): def test_performance(msg, num_trials, mode, hostname, port):
'''
Runs num_trials transmissions of msg through the specified channel and reports
per-trial decoded messages and error counts, plus overall accuracy.
'''
if len(msg) != 40: if len(msg) != 40:
raise ValueError('Message must be exactly 40 characters.') raise ValueError('Message must be exactly 40 characters.')
# Build codebook for 64 symbols, 40 blocks
# build the same codebook
C = encoder.make_codebook(r=5, num_blocks=40) C = encoder.make_codebook(r=5, num_blocks=40)
successes = 0 successes = 0
print(f"Original message: {msg!r}")
print(f"Original message: {msg}")
print(f"Running {num_trials} trials over '{mode}' channel...\n") print(f"Running {num_trials} trials over '{mode}' channel...\n")
for i in range(num_trials): for i in range(num_trials):
# Transmit # TX → channel → RX
c = transmit(msg, C) c = transmit(msg, C)
# Channel
Y = receive(c, mode, hostname, port) Y = receive(c, mode, hostname, port)
# Decode
est = decoder.decode_blocks(Y, C) # decode with stateestimation
# Count character errors est, s_est = decoder.decode_blocks_with_state(Y, C)
errors = sum(1 for a, b in zip(est, msg) if a != b)
# Tally success if no errors # count char errors
errors = sum(1 for a,b in zip(est, msg) if a!=b)
if errors == 0: if errors == 0:
successes += 1 successes += 1
# Print trial result
print(f"Trial {i+1:3d}: Decoded: '{est}' | Errors: {errors}")
pct = successes / num_trials * 100 print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
pct = 100 * successes / num_trials
print("\n=== Summary ===") print("\n=== Summary ===")
print(f"Total trials: {num_trials}") print(f" Total trials: {num_trials}")
print(f"Perfect decodings: {successes}") print(f" Perfect decodings: {successes}")
print(f"Overall accuracy: {pct:.2f}%") print(f" Overall accuracy: {pct:.2f}%")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Test communication system performance.') p = argparse.ArgumentParser(description='Test comms system')
parser.add_argument('--message', '-m', type=str, required=True, p.add_argument('--message','-m', required=True, help='40-char message')
help='40-character message to send.') p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
parser.add_argument('--trials', '-n', type=int, default=200, p.add_argument('--mode', choices=['local','server'], default='local')
help='Number of trials.') p.add_argument('--hostname', default='iscsrv72.epfl.ch')
parser.add_argument('--mode', choices=['local','server'], default='local', p.add_argument('--port', type=int, default=80)
help="Channel mode: 'local' or 'server'.") p.add_argument('--input_file','-i', default=None,
parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch', help='(server mode) where to write input.txt')
help='Server hostname for server mode.') p.add_argument('--output_file','-o',default=None,
parser.add_argument('--port', type=int, default=80, help='(server mode) where to write output.txt')
help='Server port for server mode.') return p.parse_args()
parser.add_argument('--input_file', '-i', type=str, default=None,
help='Path to write server input samples (input.txt).')
parser.add_argument('--output_file', '-o', type=str, default=None,
help='Path to write server output samples (output.txt).')
return parser.parse_args()
if __name__ == '__main__': if __name__=='__main__':
args = parse_args() args = parse_args()
INPUT_FILE = args.input_file INPUT_FILE = args.input_file
OUTPUT_FILE = args.output_file OUTPUT_FILE = args.output_file
test_performance(args.message, args.trials, args.mode, args.hostname, args.port) test_performance(args.message, args.trials,
args.mode, args.hostname, args.port)