
// DistStrategy.java — console simulation of distributed training coordinated by a central server.
import java.util.ArrayList;
import java.util.List;

// ----- DataPartitioner -----
/**
 * Simulated stand-in for the step that divides the dataset into mini-batches.
 *
 * <p>No data is actually handled; the only effect is a progress line on standard output.
 */
class DataPartitioner {
    /** Progress line emitted when the split step runs. */
    private static final String SPLIT_LINE =
            "[DataPartitioner] Splitting dataset into mini-batches...";

    /** Announces the split step by printing one line to {@code System.out}. */
    public void split() {
        System.out.println(SPLIT_LINE);
    }
}

// ----- Worker -----
/**
 * Simulated training worker.
 *
 * <p>Every operation only prints a progress line to standard output; no gradients or model
 * parameters are actually stored or exchanged.
 */
class Worker {
    /** Announces gradient computation on this worker's local data. */
    public void compute() {
        announce("[Worker] Computing gradients on local data...");
    }

    /** Announces an update of this worker's local model copy. */
    public void update() {
        announce("[Worker] Updating local model copy...");
    }

    /** Announces receipt of model parameters from the central server. */
    public void receiveModel() {
        announce("[Worker] Receiving model parameters from Central Server...");
    }

    /** Announces that gradients are being sent back to the central server. */
    public void sendGradients() {
        announce("[Worker] Sending gradients back to Central Server...");
    }

    /**
     * Prints a single line. {@code System.out} is looked up on each call (not cached), so a
     * stream installed later via {@code System.setOut} is honored.
     */
    private static void announce(String line) {
        System.out.println(line);
    }
}

// ----- CentralServer -----
/**
 * Simulated central server that hands model parameters to workers and combines their gradients.
 *
 * <p>Every operation only prints a progress line to standard output; no model state is kept.
 */
class CentralServer {
    /** Progress line for the distribution step. */
    private static final String DISTRIBUTE_LINE =
            "[CentralServer] Distributing model parameters to workers...";

    /** Progress line for the aggregation step. */
    private static final String AGGREGATE_LINE =
            "[CentralServer] Aggregating gradients from workers...";

    /** Progress line for the global-model update step. */
    private static final String UPDATE_LINE =
            "[CentralServer] Updating global model parameters...";

    /** Announces that model parameters are being sent to the workers. */
    public void distribute() {
        System.out.println(DISTRIBUTE_LINE);
    }

    /** Announces that worker gradients are being aggregated. */
    public void aggregate() {
        System.out.println(AGGREGATE_LINE);
    }

    /** Announces that the global model parameters are being updated. */
    public void updateGlobalModel() {
        System.out.println(UPDATE_LINE);
    }
}

// ----- Main Program -----
/**
 * Console walkthrough of one round of distributed training coordinated by a central server.
 *
 * <p>Nothing is actually computed or transferred; each participant only prints what it would do.
 */
public class DistStrategy {
    /** Number of simulated workers taking part in the round. */
    private static final int WORKER_COUNT = 3;

    /**
     * Runs a single simulated training round.
     *
     * <p>The round has five steps: partition the data; broadcast the model; have every worker
     * compute and send gradients; aggregate and update the global model on the server; then
     * broadcast the updated model.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        CentralServer central = new CentralServer();
        DataPartitioner partitioner = new DataPartitioner();
        List<Worker> workers = createWorkers(WORKER_COUNT);

        System.out.println("--- Simulation of Distributed Training ---\n");

        // Step 1: Partition data
        partitioner.split();

        // Step 2: Distribute model to workers
        broadcastModel(central, workers);

        // Step 3: Workers compute gradients
        for (Worker w : workers) {
            w.compute();
            w.sendGradients();
        }

        // Step 4: Central aggregates and updates global model
        central.aggregate();
        central.updateGlobalModel();

        // Step 5: Distribute updated model again. As in Step 2, every worker must actually
        // receive the broadcast, otherwise the updated parameters never reach them.
        // NOTE(review): Worker.update() is never invoked in this simulation — confirm whether
        // it should follow receiveModel() here.
        broadcastModel(central, workers);

        System.out.println("\n--- End of Simulation ---");
    }

    /**
     * Creates the simulated worker pool.
     *
     * @param count number of workers to create
     * @return a new mutable list holding {@code count} fresh workers
     */
    private static List<Worker> createWorkers(int count) {
        List<Worker> workers = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            workers.add(new Worker());
        }
        return workers;
    }

    /**
     * Has the server distribute the current model, then has every worker receive it. Keeping
     * both halves together ensures no distribution step is left without a matching receipt.
     *
     * @param central the server sending the model
     * @param workers the workers receiving it, in iteration order
     */
    private static void broadcastModel(CentralServer central, List<Worker> workers) {
        central.distribute();
        for (Worker w : workers) {
            w.receiveModel();
        }
    }
}
