import java.io.*;
import java.util.HashMap;
import java.util.Map;

// ----- ModelState -----
/**
 * Mutable snapshot of training progress that {@code Storage} implementations save and load.
 */
class ModelState {
    /** Named model parameters; empty for a freshly constructed state. */
    Map<String, Double> parameters;
    /** Most recent epoch number; 0 for a freshly constructed state. */
    int epoch;
    /** Metric recorded for {@code epoch} (checkpointing treats larger as better); 0.0 initially. */
    double metric;

    ModelState() {
        parameters = new HashMap<>();
        epoch = 0;
        metric = 0.0;
    }
}

// ----- Storage Interface -----
/** Persistence backend used to read and write {@link ModelState} checkpoints by path. */
interface Storage {
    /**
     * Restores checkpoint data stored at {@code path} into {@code data}.
     *
     * <p>Implementations may leave {@code data} unchanged when nothing is stored there
     * ({@code LocalStorage} does this for a missing file).
     *
     * @param data state object to populate in place
     * @param path location of the checkpoint
     * @throws IOException if the checkpoint exists but cannot be read
     */
    void load(ModelState data, String path) throws IOException;

    /**
     * Persists {@code data} to {@code path}.
     *
     * @param data state to write
     * @param path destination of the checkpoint
     * @throws IOException if the checkpoint cannot be written
     */
    void save(ModelState data, String path) throws IOException;
}

// ----- LocalStorage -----
class LocalStorage implements Storage {
    /**
     * Writes the epoch and metric of {@code data} to {@code path} as the single line
     * {@code "<epoch> <metric>"}, replacing any existing file.
     *
     * <p>NOTE(review): {@code ModelState.parameters} is not written, so a checkpoint
     * restores only the epoch and metric; confirm this is intended. The file is also
     * overwritten in place (not via temp file + rename), so a crash mid-write can leave
     * it truncated. {@link #load} treats an empty file as "no checkpoint".
     *
     * @param data state to persist
     * @param path destination file path
     * @throws IOException if the file cannot be opened, written, or closed
     */
    @Override
    public void save(ModelState data, String path) throws IOException {
        // BufferedWriter rather than PrintWriter: PrintWriter never throws on write
        // failures (it only sets an internal error flag), so a failed save would go
        // unnoticed and still be reported as successful.
        try (BufferedWriter out = new BufferedWriter(new FileWriter(path))) {
            out.write(data.epoch + " " + data.metric);
            out.newLine();
        }
        // Report success only after close() has flushed the data without error.
        System.out.println("Saved checkpoint to " + path);
    }

    /**
     * Restores the epoch and metric from the checkpoint at {@code path} into {@code data}.
     *
     * <p>{@code data} is left untouched when the file does not exist, is empty, or its first
     * line has fewer than two whitespace-separated fields.
     *
     * @param data state object updated in place
     * @param path checkpoint file path
     * @throws IOException if the file cannot be read, or its first two fields are not a
     *     valid int and double (in which case {@code data} is left untouched)
     */
    @Override
    public void load(ModelState data, String path) throws IOException {
        File file = new File(path);
        if (!file.exists()) {
            return; // nothing saved yet
        }
        String header;
        try (BufferedReader in = new BufferedReader(new FileReader(file))) {
            header = in.readLine();
        }
        // readLine() returns null for an empty file (e.g. a save interrupted right after
        // truncation); treat it like a header with too few fields and ignore it.
        if (header == null) {
            return;
        }
        // Split on any whitespace run so stray or trailing spaces don't break parsing.
        String[] parts = header.trim().split("\\s+");
        if (parts.length < 2) {
            return;
        }
        int epoch;
        double metric;
        try {
            epoch = Integer.parseInt(parts[0]);
            metric = Double.parseDouble(parts[1]);
        } catch (NumberFormatException e) {
            throw new IOException("Malformed checkpoint header in " + path + ": " + header, e);
        }
        // Assign only after both fields parsed, so a bad file never leaves data half-updated.
        data.epoch = epoch;
        data.metric = metric;
        System.out.println("Loaded checkpoint from " + path);
    }
}

// ----- CheckpointManager -----
/**
 * Saves a {@link ModelState} to a single checkpoint location only when its metric beats
 * the best one this manager knows to be stored there (larger metric = better).
 */
class CheckpointManager {
    private final Storage storage;
    private final String path;
    /** Highest metric known to be stored at {@code path}; starts below every real value. */
    private double bestMetric = Double.NEGATIVE_INFINITY;

    /**
     * @param storage backend used for reading and writing the checkpoint
     * @param path location of the checkpoint within {@code storage}
     */
    public CheckpointManager(Storage storage, String path) {
        this.storage = storage;
        this.path = path;
    }

    /**
     * Persists {@code state} if its metric is strictly greater than the best stored so far.
     * A NaN metric is never saved.
     *
     * @param state current training state
     * @throws IOException if the storage write fails; the best metric is then left unchanged
     */
    public void saveCheckpoint(ModelState state) throws IOException {
        if (state.metric > bestMetric) {
            storage.save(state, path);
            // Record the new best only after it was actually written, so a failed save
            // doesn't block later saves of equally good states.
            bestMetric = state.metric;
        }
    }

    /**
     * Loads the stored checkpoint into {@code state} and, when something was restored,
     * seeds the best-so-far metric from it. Without this seeding, the first save after a
     * resume overwrites the stored best even when its metric is worse.
     *
     * @param state state object updated in place by the storage backend
     * @throws IOException if the storage read fails
     */
    public void loadCheckpoint(ModelState state) throws IOException {
        int epochBefore = state.epoch;
        double metricBefore = state.metric;
        storage.load(state, path);
        // Storage.load gives no success signal, so treat any change to the state as
        // "a checkpoint was restored". If the stored values happen to equal the current
        // ones, nothing is seeded, which matches the previous behavior.
        boolean restored = state.epoch != epochBefore
                || Double.compare(state.metric, metricBefore) != 0;
        if (restored && state.metric > bestMetric) {
            bestMetric = state.metric;
        }
    }
}

// ----- Trainer -----
/** Simulated training loop that resumes from, and reports progress to, a {@link CheckpointManager}. */
class Trainer {
    private final CheckpointManager manager;
    private final ModelState state = new ModelState();

    /**
     * @param manager checkpoint manager used to resume and to save progress
     */
    public Trainer(CheckpointManager manager) {
        this.manager = manager;
    }

    /**
     * Runs simulated epochs up to and including {@code epochs}. A stored checkpoint is loaded
     * first, and training continues from the epoch after the restored one. After each epoch,
     * the state is passed to the manager, which decides whether to persist it.
     *
     * @param epochs last epoch number to run
     * @throws IOException if loading or saving a checkpoint fails
     */
    public void train(int epochs) throws IOException {
        manager.loadCheckpoint(state);
        int next = state.epoch + 1;
        while (next <= epochs) {
            state.epoch = next;
            state.metric = simulatedMetric(next);
            System.out.printf("Epoch %d | metric = %.2f%n", next, state.metric);
            manager.saveCheckpoint(state);
            next++;
        }
    }

    /** Stand-in for real evaluation: a metric that grows linearly with the epoch. */
    private static double simulatedMetric(int epoch) {
        return 0.8 + 0.05 * epoch;
    }
}

// ----- Main -----
/** Entry point: runs a 5-epoch simulated training with checkpoints in {@code checkpoint.txt}. */
public class Checkpoints {
    /**
     * Runs the demo. On an I/O failure, prints the error and exits with status 1.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try {
            // NOTE: If checkpoint.txt exists, training resumes from it; delete the file
            // before a run to start from scratch.
            Storage storage = new LocalStorage();
            CheckpointManager manager = new CheckpointManager(storage, "checkpoint.txt");
            Trainer trainer = new Trainer(manager);
            trainer.train(5);
        } catch (IOException e) {
            System.err.println("Error: " + e.getMessage());
            // Exit non-zero so calling scripts/CI can tell that the run failed.
            System.exit(1);
        }
    }
}
