diff --git a/decoder.py b/decoder.py index 75a0750..dae00ce 100644 --- a/decoder.py +++ b/decoder.py @@ -5,7 +5,7 @@ from encoder import ALPHABET, G def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: """ - Decode received samples by maximum correlation score + Decode received samples by maximum correlation score (state unknown, per-block). """ n = C.shape[1] half = n // 2 @@ -15,7 +15,7 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: sqrtG = np.sqrt(G) recovered = [] for k in range(num): - Yb = Y[k*n:(k+1)*n] + Yb = Y[k * n:(k + 1) * n] Ye, Yo = Yb[0::2], Yb[1::2] s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T) s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T) @@ -25,8 +25,37 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: return "".join(recovered) -def count_errors(orig: str, est: str): +def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int): """ - List mismatches between orig and est + 1) Estimate the single channel state s∈{1,2} by comparing total energy + on even vs odd positions across the entire Y. + 2) Decode **all** blocks using that one state’s scoring rule. + + Returns (decoded_string, estimated_state). """ - return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e] \ No newline at end of file + n = C.shape[1] + half = n // 2 + num = Y.size // n + C1 = C[:, :half] + C2 = C[:, half:] + sqrtG = np.sqrt(G) + + # 1) state estimate from full-length energies + Ye_all = Y[0::2] + Yo_all = Y[1::2] + E_even = np.sum(Ye_all ** 2) + E_odd = np.sum(Yo_all ** 2) + s_est = 1 if E_even > E_odd else 2 + + recovered = [] + for k in range(num): + Yb = Y[k * n:(k + 1) * n] + Ye, Yo = Yb[0::2], Yb[1::2] + if s_est == 1: + score = sqrtG * (Ye @ C1.T) + (Yo @ C2.T) + else: + score = (Ye @ C1.T) + sqrtG * (Yo @ C2.T) + best = int(np.argmax(score)) + recovered.append(ALPHABET[best]) + + return "".join(recovered), s_est diff --git a/main.py b/main.py index 5d63376..1e90795 100644 --- a/main.py +++ b/main.py @@ -16,127 +16,104 @@ INPUT_FILE = None OUTPUT_FILE = None def transmit(msg, C): - ''' - Transmitter: encodes the message into real-valued samples using the codebook C. - ''' return encoder.encode_message(msg, C) def receive_local(c): - ''' - Sends the samples through the local channel simulation. - ''' return channel.channel(c) def receive_server(c, hostname, port): - ''' - Sends the samples to the remote server via client.py. If INPUT_FILE and/or - OUTPUT_FILE are set, uses those filenames (and preserves them); otherwise uses temporary files. - ''' global INPUT_FILE, OUTPUT_FILE - # Determine input file path + # write or temp-file for input if INPUT_FILE: - in_name = INPUT_FILE + in_name, delete_in = INPUT_FILE, False np.savetxt(in_name, c) - delete_in = False else: - with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f: - np.savetxt(in_f.name, c) - in_name = in_f.name - delete_in = True - # Determine output file path + tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False) + np.savetxt(tf.name, c) + in_name, delete_in = tf.name, True + tf.close() + # write or temp-file for output if OUTPUT_FILE: - out_name = OUTPUT_FILE - delete_out = False + out_name, delete_out = OUTPUT_FILE, False else: - out_fd, out_name = tempfile.mkstemp(suffix='.txt') - os.close(out_fd) + fd, out_name = tempfile.mkstemp(suffix='.txt') + os.close(fd) delete_out = True - # Invoke client.py + + # call client.py cmd = [ sys.executable, str(pathlib.Path(__file__).parent / 'client.py'), - '--input_file', in_name, + '--input_file', in_name, '--output_file', out_name, '--srv_hostname', hostname, - '--srv_port', str(port) + '--srv_port', str(port) ] try: subprocess.run(cmd, check=True) Y = np.loadtxt(out_name) finally: - if delete_in and os.path.exists(in_name): - os.remove(in_name) - if delete_out and os.path.exists(out_name): - os.remove(out_name) + if delete_in and os.path.exists(in_name): os.remove(in_name) + if delete_out and os.path.exists(out_name): os.remove(out_name) + return Y def receive(c, mode, hostname, port): - ''' - Wrapper to choose local or server channel. - ''' if mode == 'local': return receive_local(c) elif mode == 'server': return receive_server(c, hostname, port) else: - raise ValueError("Mode must be 'local' or 'server'") + raise ValueError("mode must be 'local' or 'server'") def test_performance(msg, num_trials, mode, hostname, port): - ''' - Runs num_trials transmissions of msg through the specified channel and reports - per-trial decoded messages and error counts, plus overall accuracy. - ''' if len(msg) != 40: raise ValueError('Message must be exactly 40 characters.') - # Build codebook for 64 symbols, 40 blocks + + # build the same codebook C = encoder.make_codebook(r=5, num_blocks=40) successes = 0 - - print(f"Original message: {msg}") + print(f"Original message: {msg!r}") print(f"Running {num_trials} trials over '{mode}' channel...\n") for i in range(num_trials): - # Transmit + # TX → channel → RX c = transmit(msg, C) - # Channel Y = receive(c, mode, hostname, port) - # Decode - est = decoder.decode_blocks(Y, C) - # Count character errors - errors = sum(1 for a, b in zip(est, msg) if a != b) - # Tally success if no errors + + # decode with state‐estimation + est, s_est = decoder.decode_blocks_with_state(Y, C) + + # count char errors + errors = sum(1 for a,b in zip(est, msg) if a!=b) if errors == 0: successes += 1 - # Print trial result - print(f"Trial {i+1:3d}: Decoded: '{est}' | Errors: {errors}") - pct = successes / num_trials * 100 + print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}") + + pct = 100 * successes / num_trials print("\n=== Summary ===") - print(f"Total trials: {num_trials}") - print(f"Perfect decodings: {successes}") - print(f"Overall accuracy: {pct:.2f}%") + print(f" Total trials: {num_trials}") + print(f" Perfect decodings: {successes}") + print(f" Overall accuracy: {pct:.2f}%") def parse_args(): - parser = argparse.ArgumentParser(description='Test communication system performance.') - parser.add_argument('--message', '-m', type=str, required=True, - help='40-character message to send.') - parser.add_argument('--trials', '-n', type=int, default=200, - help='Number of trials.') - parser.add_argument('--mode', choices=['local','server'], default='local', - help="Channel mode: 'local' or 'server'.") - parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch', - help='Server hostname for server mode.') - parser.add_argument('--port', type=int, default=80, - help='Server port for server mode.') - parser.add_argument('--input_file', '-i', type=str, default=None, - help='Path to write server input samples (input.txt).') - parser.add_argument('--output_file', '-o', type=str, default=None, - help='Path to write server output samples (output.txt).') - return parser.parse_args() + p = argparse.ArgumentParser(description='Test comms system') + p.add_argument('--message','-m', required=True, help='40-char message') + p.add_argument('--trials','-n', type=int, default=200, help='Number of trials') + p.add_argument('--mode', choices=['local','server'], default='local') + p.add_argument('--hostname', default='iscsrv72.epfl.ch') + p.add_argument('--port', type=int, default=80) + p.add_argument('--input_file','-i', default=None, + help='(server mode) where to write input.txt') + p.add_argument('--output_file','-o',default=None, + help='(server mode) where to write output.txt') + return p.parse_args() -if __name__ == '__main__': +if __name__=='__main__': args = parse_args() - INPUT_FILE = args.input_file + INPUT_FILE = args.input_file OUTPUT_FILE = args.output_file - test_performance(args.message, args.trials, args.mode, args.hostname, args.port) + test_performance(args.message, args.trials, + args.mode, args.hostname, args.port)