/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.samza.operators.impl;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import org.apache.samza.Partition;
import org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.config.MapConfig;
import org.apache.samza.config.StreamConfig;
import org.apache.samza.container.TaskName;
import org.apache.samza.context.Context;
import org.apache.samza.context.MockContext;
import org.apache.samza.context.TaskContextImpl;
import org.apache.samza.system.descriptors.GenericInputDescriptor;
import org.apache.samza.system.descriptors.GenericOutputDescriptor;
import org.apache.samza.system.descriptors.GenericSystemDescriptor;
import org.apache.samza.job.model.ContainerModel;
import org.apache.samza.job.model.JobModel;
import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metrics.MetricsRegistryMap;
import org.apache.samza.operators.KV;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.OutputStream;
import org.apache.samza.operators.functions.ClosableFunction;
import org.apache.samza.operators.functions.FilterFunction;
import org.apache.samza.operators.functions.InitableFunction;
import org.apache.samza.operators.functions.JoinFunction;
import org.apache.samza.operators.functions.MapFunction;
import org.apache.samza.operators.spec.OperatorSpec.OpCode;
import org.apache.samza.serializers.IntegerSerde;
import org.apache.samza.serializers.KVSerde;
import org.apache.samza.serializers.Serde;
import org.apache.samza.serializers.StringSerde;
import org.apache.samza.storage.kv.KeyValueStore;
import org.apache.samza.system.IncomingMessageEnvelope;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.SystemStreamPartition;
import org.apache.samza.task.MessageCollector;
import org.apache.samza.task.TaskCoordinator;
import org.apache.samza.testUtils.StreamTestUtils;
import org.apache.samza.util.Clock;
import org.apache.samza.util.SystemClock;
import org.apache.samza.util.TimestampedValue;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TestOperatorImplGraph {
  private Context context;

  /**
   * Creates a fresh {@link MockContext} before each test and stubs a default job config, a task model
   * named "task 0", and empty task-level and container-level metrics registries on it.
   */
  @Before
  public void setup() {
    this.context = new MockContext();

    // individual tests can override this config if necessary
    when(this.context.getJobContext().getConfig()).thenReturn(mock(Config.class));

    // default task identity; tests that track per-task state replace this task model
    TaskModel defaultTaskModel = mock(TaskModel.class);
    TaskName defaultTaskName = new TaskName("task 0");
    when(defaultTaskModel.getTaskName()).thenReturn(defaultTaskName);
    when(this.context.getTaskContext().getTaskModel()).thenReturn(defaultTaskModel);

    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
    when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
  }

  /**
   * Clears the static per-task bookkeeping kept by {@code BaseTestFunction} after each test so that
   * recorded init/close calls do not carry over into the next test.
   */
  @After
  public void tearDown() {
    BaseTestFunction.reset();
  }

  /** An application that declares no streams must produce an operator graph without input operators. */
  @Test
  public void testEmptyChain() {
    StreamApplicationDescriptorImpl emptyAppDescriptor =
        new StreamApplicationDescriptorImpl(appDesc -> { }, mock(Config.class));
    Clock clock = mock(Clock.class);

    OperatorImplGraph emptyGraph = new OperatorImplGraph(emptyAppDescriptor.getOperatorSpecGraph(), context, clock);

    assertEquals(0, emptyGraph.getAllInputOperators().size());
  }

  /**
   * Builds input -> filter -> map -> sendTo and checks that the operator implementations form a single
   * chain: each operator has exactly one registered downstream operator until the terminal send-to.
   */
  @Test
  public void testLinearChain() {
    String inputStreamId = "input";
    String inputSystem = "input-system";
    String inputPhysicalName = "input-stream";
    String outputStreamId = "output";
    String outputSystem = "output-system";
    String outputPhysicalName = "output-stream";
    String intermediateSystem = "intermediate-system";

    // job identity, default system and the stream-id -> physical stream bindings
    HashMap<String, String> jobConfigs = new HashMap<>();
    jobConfigs.put(JobConfig.JOB_NAME, "jobName");
    jobConfigs.put(JobConfig.JOB_ID, "jobId");
    jobConfigs.put(JobConfig.JOB_DEFAULT_SYSTEM, intermediateSystem);
    StreamTestUtils.addStreamConfigs(jobConfigs, inputStreamId, inputSystem, inputPhysicalName);
    StreamTestUtils.addStreamConfigs(jobConfigs, outputStreamId, outputSystem, outputPhysicalName);
    Config jobConfig = new MapConfig(jobConfigs);
    when(this.context.getJobContext().getConfig()).thenReturn(jobConfig);

    StreamApplicationDescriptorImpl appDescriptor = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
      GenericOutputDescriptor outputDescriptor = sd.getOutputDescriptor(outputStreamId, mock(Serde.class));
      MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor);
      OutputStream<Object> outputStream = appDesc.getOutputStream(outputDescriptor);

      MessageStream<Object> filtered = inputStream.filter(mock(FilterFunction.class));
      MessageStream<Object> mapped = filtered.map(mock(MapFunction.class));
      mapped.sendTo(outputStream);
    }, jobConfig);

    Clock clock = mock(Clock.class);
    OperatorImplGraph opImplGraph = new OperatorImplGraph(appDescriptor.getOperatorSpecGraph(), this.context, clock);

    // head of the chain: the input operator for the physical input stream
    SystemStream inputSystemStream = new SystemStream(inputSystem, inputPhysicalName);
    InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(inputSystemStream);
    assertEquals(1, inputOpImpl.registeredOperators.size());

    // input -> filter (the cast doubles as a check of the implementation class)
    OperatorImpl filterOpImpl = (FlatmapOperatorImpl) inputOpImpl.registeredOperators.iterator().next();
    assertEquals(1, filterOpImpl.registeredOperators.size());
    assertEquals(OpCode.FILTER, filterOpImpl.getOperatorSpec().getOpCode());

    // filter -> map
    OperatorImpl mapOpImpl = (FlatmapOperatorImpl) filterOpImpl.registeredOperators.iterator().next();
    assertEquals(1, mapOpImpl.registeredOperators.size());
    assertEquals(OpCode.MAP, mapOpImpl.getOperatorSpec().getOpCode());

    // map -> sendTo, which terminates the chain
    OperatorImpl sendToOpImpl = (OutputOperatorImpl) mapOpImpl.registeredOperators.iterator().next();
    assertEquals(0, sendToOpImpl.registeredOperators.size());
    assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode());
  }

  /**
   * Builds input -> partitionBy -> sendTo and checks that the partition-by operator is terminal on the
   * original input, while the send-to operator hangs off the input operator of the intermediate stream.
   */
  @Test
  public void testPartitionByChain() {
    String inputStreamId = "input";
    String inputSystem = "input-system";
    String inputPhysicalName = "input-stream";
    String outputStreamId = "output";
    String outputSystem = "output-system";
    String outputPhysicalName = "output-stream";
    String intermediateStreamId = "jobName-jobId-partition_by-p1";
    String intermediateSystem = "intermediate-system";

    // job identity, default (intermediate) system and the stream-id -> physical stream bindings
    HashMap<String, String> jobConfigs = new HashMap<>();
    jobConfigs.put(JobConfig.JOB_NAME, "jobName");
    jobConfigs.put(JobConfig.JOB_ID, "jobId");
    jobConfigs.put(JobConfig.JOB_DEFAULT_SYSTEM, intermediateSystem);
    StreamTestUtils.addStreamConfigs(jobConfigs, inputStreamId, inputSystem, inputPhysicalName);
    StreamTestUtils.addStreamConfigs(jobConfigs, outputStreamId, outputSystem, outputPhysicalName);
    Config jobConfig = new MapConfig(jobConfigs);
    when(this.context.getJobContext().getConfig()).thenReturn(jobConfig);

    StreamApplicationDescriptorImpl appDescriptor = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericSystemDescriptor osd = new GenericSystemDescriptor(outputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor = isd.getInputDescriptor(inputStreamId, mock(Serde.class));
      GenericOutputDescriptor outputDescriptor = osd.getOutputDescriptor(outputStreamId,
          KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class)));
      MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor);
      OutputStream<KV<Integer, String>> outputStream = appDesc.getOutputStream(outputDescriptor);

      inputStream
          .partitionBy(Object::hashCode, Object::toString,
              KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class)), "p1")
          .sendTo(outputStream);
    }, jobConfig);

    // single container / single task job model whose task owns no system stream partitions
    TaskModel onlyTaskModel = mock(TaskModel.class);
    when(onlyTaskModel.getSystemStreamPartitions()).thenReturn(Collections.emptySet());
    ContainerModel onlyContainerModel = mock(ContainerModel.class);
    when(onlyContainerModel.getTasks()).thenReturn(Collections.singletonMap(new TaskName("task 0"), onlyTaskModel));
    JobModel mockJobModel = mock(JobModel.class);
    when(mockJobModel.getContainers()).thenReturn(Collections.singletonMap("0", onlyContainerModel));
    when(((TaskContextImpl) this.context.getTaskContext()).getJobModel()).thenReturn(mockJobModel);

    Clock clock = mock(Clock.class);
    OperatorImplGraph opImplGraph = new OperatorImplGraph(appDescriptor.getOperatorSpecGraph(), this.context, clock);

    SystemStream inputSystemStream = new SystemStream(inputSystem, inputPhysicalName);
    InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(inputSystemStream);
    assertEquals(1, inputOpImpl.registeredOperators.size());

    OperatorImpl partitionByOpImpl = (PartitionByOperatorImpl) inputOpImpl.registeredOperators.iterator().next();
    // is terminal but paired with an input operator
    assertEquals(0, partitionByOpImpl.registeredOperators.size());
    assertEquals(OpCode.PARTITION_BY, partitionByOpImpl.getOperatorSpec().getOpCode());

    SystemStream intermediateSystemStream = new SystemStream(intermediateSystem, intermediateStreamId);
    InputOperatorImpl repartitionedInputOpImpl = opImplGraph.getInputOperator(intermediateSystemStream);
    assertEquals(1, repartitionedInputOpImpl.registeredOperators.size());

    OperatorImpl sendToOpImpl = (OutputOperatorImpl) repartitionedInputOpImpl.registeredOperators.iterator().next();
    assertEquals(0, sendToOpImpl.registeredOperators.size());
    assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode());
  }

  /**
   * Attaches a filter and a map to the same input stream and checks that both end up registered
   * directly on the input operator.
   */
  @Test
  public void testBroadcastChain() {
    String inputStreamId = "input";
    String inputSystem = "input-system";
    String inputPhysicalName = "input-stream";

    HashMap<String, String> jobConfigs = new HashMap<>();
    jobConfigs.put(JobConfig.JOB_NAME, "test-job");
    jobConfigs.put(JobConfig.JOB_ID, "1");
    StreamTestUtils.addStreamConfigs(jobConfigs, inputStreamId, inputSystem, inputPhysicalName);
    Config jobConfig = new MapConfig(jobConfigs);
    when(this.context.getJobContext().getConfig()).thenReturn(jobConfig);

    StreamApplicationDescriptorImpl appDescriptor = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
      MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor);
      inputStream.filter(mock(FilterFunction.class));
      inputStream.map(mock(MapFunction.class));
    }, jobConfig);

    Clock clock = mock(Clock.class);
    OperatorImplGraph opImplGraph = new OperatorImplGraph(appDescriptor.getOperatorSpecGraph(), this.context, clock);

    SystemStream inputSystemStream = new SystemStream(inputSystem, inputPhysicalName);
    InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(inputSystemStream);
    assertEquals(2, inputOpImpl.registeredOperators.size());

    // gather the op codes of everything registered directly on the input operator
    Set<OpCode> downstreamOpCodes = new HashSet<>();
    for (Object registered : inputOpImpl.registeredOperators) {
      downstreamOpCodes.add(((OperatorImpl) registered).getOperatorSpec().getOpCode());
    }
    assertTrue(downstreamOpCodes.contains(OpCode.FILTER));
    assertTrue(downstreamOpCodes.contains(OpCode.MAP));
  }

  /**
   * Merges two branches of the same input (filter and map) and attaches a map after the merge. Checks
   * that exactly one merge operator exists, that the map is its only downstream operator, and that the
   * map function after the merge was initialized exactly once.
   */
  @Test
  public void testMergeChain() {
    String inputStreamId = "input";
    String inputSystem = "input-system";
    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
      MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor);
      MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class));
      MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class));
      stream1.merge(Collections.singleton(stream2))
          .map(new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m));
    }, getConfig());

    // use a dedicated task name so the function bookkeeping of this test can be looked up by it
    TaskName mockTaskName = mock(TaskName.class);
    TaskModel taskModel = mock(TaskModel.class);
    when(taskModel.getTaskName()).thenReturn(mockTaskName);
    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);

    OperatorImplGraph opImplGraph =
        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));

    // collect every operator reachable from any input operator
    HashSet<OperatorImpl> allOps = new HashSet<>();
    for (OperatorImpl inputOp : opImplGraph.getAllInputOperators()) {
      addOperatorRecursively(allOps, inputOp);
    }

    List<OperatorImpl> mergeOps = new ArrayList<>();
    for (OperatorImpl op : allOps) {
      if (op.getOperatorSpec().getOpCode() == OpCode.MERGE) {
        mergeOps.add(op);
      }
    }
    assertEquals(1, mergeOps.size());

    OperatorImpl mergeOp = mergeOps.get(0);
    assertEquals(1, mergeOp.registeredOperators.size());
    OperatorImpl mapOp = (OperatorImpl) mergeOp.registeredOperators.iterator().next();
    assertEquals(OpCode.MAP, mapOp.getOperatorSpec().getOpCode());

    // verify that the DAG after merge is only traversed & initialized once
    assertEquals(1, TestMapFunction.getInstanceByTaskName(mockTaskName, "test-map-1").numInitCalled);
  }

  /**
   * Joins two input streams and checks that: the join function is initialized once; each input gets its
   * own partial-join operator backed by the same operator spec; and feeding one message on each side
   * (with the other side's store stubbed to return a match) results in a single call to the join
   * function with the left and right messages.
   */
  @Test
  public void testJoinChain() {
    String inputStreamId1 = "input1";
    String inputStreamId2 = "input2";
    String inputSystem = "input-system";
    String inputPhysicalName1 = "input-stream1";
    String inputPhysicalName2 = "input-stream2";
    String joinOpId = "jobName-jobId-join-j1";

    HashMap<String, String> configs = new HashMap<>();
    configs.put(JobConfig.JOB_NAME, "jobName");
    configs.put(JobConfig.JOB_ID, "jobId");
    StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputPhysicalName1);
    StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputPhysicalName2);
    Config config = new MapConfig(configs);
    when(this.context.getJobContext().getConfig()).thenReturn(config);

    // every message maps to the same join key, so any left message matches any right message
    Integer joinKey = Integer.valueOf(1);
    Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey;
    JoinFunction testJoinFunction = new TestJoinFunction(joinOpId,
        (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn);

    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class));
      GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class));
      MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1);
      MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2);

      inputStream1.join(inputStream2, testJoinFunction,
          mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1");
    }, config);

    TaskName mockTaskName = mock(TaskName.class);
    TaskModel taskModel = mock(TaskModel.class);
    when(taskModel.getTaskName()).thenReturn(mockTaskName);
    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);

    KeyValueStore mockLeftStore = mock(KeyValueStore.class);
    when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore);
    KeyValueStore mockRightStore = mock(KeyValueStore.class);
    when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore);
    OperatorImplGraph opImplGraph =
        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));

    // the instance registered during init is the one the graph actually invokes; look it up once
    TestJoinFunction initializedJoinFn =
        (TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, joinOpId);

    // verify that join function is initialized once.
    assertEquals(1, initializedJoinFn.numInitCalled);

    InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName1));
    InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName2));
    PartialJoinOperatorImpl leftPartialJoinOpImpl =
        (PartialJoinOperatorImpl) inputOpImpl1.registeredOperators.iterator().next();
    PartialJoinOperatorImpl rightPartialJoinOpImpl =
        (PartialJoinOperatorImpl) inputOpImpl2.registeredOperators.iterator().next();

    assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec());
    assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl);

    // verify that left partial join operator calls getFirstKey
    Object mockLeftMessage = mock(Object.class);
    long currentTimeMillis = System.currentTimeMillis();
    when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis));
    IncomingMessageEnvelope leftMessage =
        new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockLeftMessage);
    inputOpImpl1.onMessage(leftMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

    // verify that right partial join operator calls getSecondKey
    Object mockRightMessage = mock(Object.class);
    when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis));
    IncomingMessageEnvelope rightMessage =
        new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockRightMessage);
    inputOpImpl2.onMessage(rightMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

    // verify that the join function apply is called with the correct messages on match
    assertEquals(1, initializedJoinFn.joinResults.size());
    KV joinResult = (KV) initializedJoinFn.joinResults.iterator().next();
    assertEquals(mockLeftMessage, joinResult.getKey());
    assertEquals(mockRightMessage, joinResult.getValue());
  }

  /**
   * Builds two independent two-stage map chains and checks the order in which the user functions are
   * initialized when the graph is constructed and the (reverse) order in which they are closed.
   */
  @Test
  public void testOperatorGraphInitAndClose() {
    String inputStreamId1 = "input1";
    String inputStreamId2 = "input2";
    String inputSystem = "input-system";

    // use a dedicated task name so the init/close lists of this test can be looked up by it
    TaskName mockTaskName = mock(TaskName.class);
    TaskModel taskModel = mock(TaskModel.class);
    when(taskModel.getTaskName()).thenReturn(mockTaskName);
    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);

    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class));
      GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class));
      MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1);
      MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2);

      Function mapFn = (Function & Serializable) m -> m;
      inputStream1.map(new TestMapFunction<Object, Object>("1", mapFn))
          .map(new TestMapFunction<Object, Object>("2", mapFn));

      inputStream2.map(new TestMapFunction<Object, Object>("3", mapFn))
          .map(new TestMapFunction<Object, Object>("4", mapFn));
    }, getConfig());

    OperatorImplGraph opImplGraph =
        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, SystemClock.instance());

    List<String> initializedOperators = BaseTestFunction.getInitListByTaskName(mockTaskName);

    // Assert that initialization occurs in topological order.
    assertEquals("1", initializedOperators.get(0));
    assertEquals("2", initializedOperators.get(1));
    assertEquals("3", initializedOperators.get(2));
    assertEquals("4", initializedOperators.get(3));

    // Assert that finalization occurs in reverse topological order.
    opImplGraph.close();
    List<String> closedOperators = BaseTestFunction.getCloseListByTaskName(mockTaskName);
    assertEquals("4", closedOperators.get(0));
    assertEquals("3", closedOperators.get(1));
    assertEquals("2", closedOperators.get(2));
    assertEquals("1", closedOperators.get(3));
  }

  /**
   * Builds a job model with two containers of one task each and checks how many tasks
   * {@code OperatorImplGraph.getStreamToConsumerTasks} reports per system stream: the stream whose
   * partitions are split over both tasks maps to two tasks, the stream owned by one task maps to one.
   */
  @Test
  public void testGetStreamToConsumerTasks() {
    String system = "test-system";
    String streamId0 = "test-stream-0";
    String streamId1 = "test-stream-1";

    HashMap<String, String> configs = new HashMap<>();
    configs.put(JobConfig.JOB_NAME, "test-app");
    configs.put(JobConfig.JOB_DEFAULT_SYSTEM, system);
    StreamTestUtils.addStreamConfigs(configs, streamId0, system, streamId0);
    StreamTestUtils.addStreamConfigs(configs, streamId1, system, streamId1);
    Config config = new MapConfig(configs);
    when(this.context.getJobContext().getConfig()).thenReturn(config);

    // stream 0 has two partitions (one per task); stream 1 has a single partition on task 0
    SystemStreamPartition ssp0 = new SystemStreamPartition(system, streamId0, new Partition(0));
    SystemStreamPartition ssp1 = new SystemStreamPartition(system, streamId0, new Partition(1));
    SystemStreamPartition ssp2 = new SystemStreamPartition(system, streamId1, new Partition(0));

    TaskName task0 = new TaskName("Task 0");
    TaskName task1 = new TaskName("Task 1");
    Set<SystemStreamPartition> ssps = new HashSet<>();
    ssps.add(ssp0);
    ssps.add(ssp2);
    TaskModel tm0 = new TaskModel(task0, ssps, new Partition(0));
    ContainerModel cm0 = new ContainerModel("c0", Collections.singletonMap(task0, tm0));
    TaskModel tm1 = new TaskModel(task1, Collections.singleton(ssp1), new Partition(1));
    ContainerModel cm1 = new ContainerModel("c1", Collections.singletonMap(task1, tm1));

    Map<String, ContainerModel> cms = new HashMap<>();
    cms.put(cm0.getId(), cm0);
    cms.put(cm1.getId(), cm1);

    JobModel jobModel = new JobModel(config, cms);
    Multimap<SystemStream, String> streamToTasks = OperatorImplGraph.getStreamToConsumerTasks(jobModel);
    assertEquals(2, streamToTasks.get(ssp0.getSystemStream()).size());
    assertEquals(1, streamToTasks.get(ssp2.getSystemStream()).size());
  }

  /**
   * Builds a graph with two partition-by operators ("p1" directly downstream of input3, "p2" downstream
   * of a join of input1 and input2) and checks that
   * {@code OperatorImplGraph.getIntermediateToInputStreamsMap} maps each intermediate stream to the
   * input streams that feed it.
   */
  @Test
  public void testGetOutputToInputStreams() {
    String inputStreamId1 = "input1";
    String inputStreamId2 = "input2";
    String inputStreamId3 = "input3";
    String inputSystem = "input-system";

    String outputStreamId1 = "output1";
    String outputStreamId2 = "output2";
    String outputSystem = "output-system";

    String intStreamId1 = "test-app-1-partition_by-p1";
    String intStreamId2 = "test-app-1-partition_by-p2";
    String intSystem = "test-system";

    HashMap<String, String> configs = new HashMap<>();
    configs.put(JobConfig.JOB_NAME, "test-app");
    configs.put(JobConfig.JOB_DEFAULT_SYSTEM, intSystem);
    StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputStreamId1);
    StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputStreamId2);
    StreamTestUtils.addStreamConfigs(configs, inputStreamId3, inputSystem, inputStreamId3);
    StreamTestUtils.addStreamConfigs(configs, outputStreamId1, outputSystem, outputStreamId1);
    StreamTestUtils.addStreamConfigs(configs, outputStreamId2, outputSystem, outputStreamId2);
    Config config = new MapConfig(configs);
    when(this.context.getJobContext().getConfig()).thenReturn(config);

    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
      GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
      GenericInputDescriptor inputDescriptor1 = isd.getInputDescriptor(inputStreamId1, mock(Serde.class));
      GenericInputDescriptor inputDescriptor2 = isd.getInputDescriptor(inputStreamId2, mock(Serde.class));
      GenericInputDescriptor inputDescriptor3 = isd.getInputDescriptor(inputStreamId3, mock(Serde.class));
      GenericSystemDescriptor osd = new GenericSystemDescriptor(outputSystem, "mockFactoryClass");
      GenericOutputDescriptor outputDescriptor1 = osd.getOutputDescriptor(outputStreamId1, mock(Serde.class));
      GenericOutputDescriptor outputDescriptor2 = osd.getOutputDescriptor(outputStreamId2, mock(Serde.class));
      MessageStream messageStream1 = appDesc.getInputStream(inputDescriptor1).map(m -> m);
      MessageStream messageStream2 = appDesc.getInputStream(inputDescriptor2).filter(m -> true);
      MessageStream messageStream3 =
          appDesc.getInputStream(inputDescriptor3)
              .filter(m -> true)
              .partitionBy(m -> "m", m -> m, mock(KVSerde.class),  "p1")
              .map(m -> m);
      OutputStream<Object> outputStream1 = appDesc.getOutputStream(outputDescriptor1);
      OutputStream<Object> outputStream2 = appDesc.getOutputStream(outputDescriptor2);

      messageStream1
          .join(messageStream2, mock(JoinFunction.class),
              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1")
          .partitionBy(m -> "m", m -> m, mock(KVSerde.class), "p2")
          .sendTo(outputStream1);
      messageStream3
          .join(messageStream2, mock(JoinFunction.class),
              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2")
          .sendTo(outputStream2);
    }, config);

    Multimap<SystemStream, SystemStream> outputToInput =
        OperatorImplGraph.getIntermediateToInputStreamsMap(graphSpec.getOperatorSpecGraph(), new StreamConfig(config));

    // "p2" sits downstream of the join of input1 and input2
    Collection<SystemStream> inputs = outputToInput.get(new SystemStream(intSystem, intStreamId2));
    assertEquals(2, inputs.size());
    assertTrue(inputs.contains(new SystemStream(inputSystem, inputStreamId1)));
    assertTrue(inputs.contains(new SystemStream(inputSystem, inputStreamId2)));

    // "p1" is fed by input3 only
    inputs = outputToInput.get(new SystemStream(intSystem, intStreamId1));
    assertEquals(1, inputs.size());
    assertEquals(new SystemStream(inputSystem, inputStreamId3), inputs.iterator().next());
  }

  /**
   * Feeds hand-built stream-to-consumer-task and intermediate-to-input-stream maps into
   * {@code OperatorImplGraph.getProducerTaskCountForIntermediateStreams} and checks the number of
   * distinct producer tasks reported for each intermediate stream.
   */
  @Test
  public void testGetProducerTaskCountForIntermediateStreams() {
    String inputStreamId1 = "input1";
    String inputStreamId2 = "input2";
    String inputStreamId3 = "input3";
    String inputSystem1 = "system1";
    String inputSystem2 = "system2";

    SystemStream input1 = new SystemStream(inputSystem1, inputStreamId1);
    SystemStream input2 = new SystemStream(inputSystem2, inputStreamId2);
    SystemStream input3 = new SystemStream(inputSystem2, inputStreamId3);

    SystemStream int1 = new SystemStream(inputSystem1, "int1");
    SystemStream int2 = new SystemStream(inputSystem1, "int2");

    /*
     * the task assignment looks like the following:
     *
     * input1 -----> task0, task1 -----> int1
     *                                    ^
     * input2 ------> task1, task2--------|
     *                                    v
     * input3 ------> task1 -----------> int2
     *
     */
    String task0 = "Task 0";
    String task1 = "Task 1";
    String task2 = "Task 2";

    Multimap<SystemStream, String> streamToConsumerTasks = HashMultimap.create();
    streamToConsumerTasks.put(input1, task0);
    streamToConsumerTasks.put(input1, task1);
    streamToConsumerTasks.put(input2, task1);
    streamToConsumerTasks.put(input2, task2);
    streamToConsumerTasks.put(input3, task1);
    streamToConsumerTasks.put(int1, task0);
    streamToConsumerTasks.put(int1, task1);
    streamToConsumerTasks.put(int2, task0);

    Multimap<SystemStream, SystemStream> intermediateToInputStreams = HashMultimap.create();
    intermediateToInputStreams.put(int1, input1);
    intermediateToInputStreams.put(int1, input2);

    intermediateToInputStreams.put(int2, input2);
    intermediateToInputStreams.put(int2, input3);

    Map<SystemStream, Integer> counts = OperatorImplGraph.getProducerTaskCountForIntermediateStreams(
        streamToConsumerTasks, intermediateToInputStreams);
    // int1 is produced by the consumers of input1 and input2: task0, task1, task2
    assertEquals(3, counts.get(int1).intValue());
    // int2 is produced by the consumers of input2 and input3: task1, task2
    assertEquals(2, counts.get(int2).intValue());
  }

  /**
   * Adds {@code op} and every operator reachable from it through {@code registeredOperators} to
   * {@code s}.
   *
   * The traversal is an iterative breadth-first walk. The children of an operator are only enqueued
   * the first time that operator is added to {@code s}, so sub-graphs shared by several upstream
   * operators are walked once.
   *
   * @param s the set that accumulates the visited operators; also serves as the visited marker
   * @param op the operator to start the traversal from
   */
  private void addOperatorRecursively(HashSet<OperatorImpl> s, OperatorImpl op) {
    List<OperatorImpl> worklist = new ArrayList<>();
    worklist.add(op);
    // index-based scan: the list grows while it is being walked and nothing is removed from its head
    for (int next = 0; next < worklist.size(); next++) {
      OperatorImpl current = worklist.get(next);
      if (s.add(current)) {
        worklist.addAll(current.registeredOperators);
      }
    }
  }

  /**
   * Returns a minimal job config containing only the job name ("test-job") and job id ("1").
   */
  private Config getConfig() {
    Map<String, String> jobIdentity = new HashMap<>();
    jobIdentity.put(JobConfig.JOB_ID, "1");
    jobIdentity.put(JobConfig.JOB_NAME, "test-job");
    return new MapConfig(jobIdentity);
  }

  /**
   * A {@link MapFunction} that delegates to a supplied {@link Function} and inherits the init/close
   * bookkeeping of {@code BaseTestFunction}.
   *
   * @param <M> type of the input message
   * @param <OM> type of the mapped message
   */
  private static class TestMapFunction<M, OM> extends BaseTestFunction implements MapFunction<M, OM> {
    // the actual mapping logic, supplied by the test
    final Function<M, OM> mapFn;

    public TestMapFunction(String opId, Function<M, OM> mapFn) {
      super(opId);
      this.mapFn = mapFn;
    }

    @Override
    public OM apply(M message) {
      OM mapped = mapFn.apply(message);
      return mapped;
    }
  }

  /**
   * A {@link JoinFunction} that delegates joining and key extraction to supplied functions, records
   * every join result it produces, and inherits the init/close bookkeeping of {@code BaseTestFunction}.
   *
   * @param <K> type of the join key
   * @param <M> type of the first (left) message
   * @param <JM> type of the second (right) message
   * @param <RM> type of the join result
   */
  private static class TestJoinFunction<K, M, JM, RM> extends BaseTestFunction implements JoinFunction<K, M, JM, RM> {
    final BiFunction<M, JM, RM> joiner;
    final Function<M, K> firstKeyFn;
    final Function<JM, K> secondKeyFn;
    // every result returned by apply, so tests can inspect what was joined
    final Collection<RM> joinResults = new HashSet<>();

    public TestJoinFunction(String opId, BiFunction<M, JM, RM> joiner, Function<M, K> firstKeyFn, Function<JM, K> secondKeyFn) {
      super(opId);
      this.firstKeyFn = firstKeyFn;
      this.secondKeyFn = secondKeyFn;
      this.joiner = joiner;
    }

    @Override
    public K getFirstKey(M message) {
      return firstKeyFn.apply(message);
    }

    @Override
    public K getSecondKey(JM message) {
      return secondKeyFn.apply(message);
    }

    @Override
    public RM apply(M message, JM otherMessage) {
      RM joined = joiner.apply(message, otherMessage);
      joinResults.add(joined);
      return joined;
    }
  }

  /**
   * Base class for the user functions used in these tests. It records, in static maps keyed by task
   * name, which function instance was initialized for each op id and the order in which op ids were
   * initialized and closed, so tests can assert on the operator graph's lifecycle handling.
   *
   * The static maps are shared by all instances; {@link #reset()} clears them.
   */
  private abstract static class BaseTestFunction implements InitableFunction, ClosableFunction, Serializable {
    // task name -> (op id -> instance that was initialized for that op id)
    static Map<TaskName, Map<String, BaseTestFunction>> perTaskFunctionMap = new HashMap<>();
    // task name -> op ids in the order init was called
    static Map<TaskName, List<String>> perTaskInitList = new HashMap<>();
    // task name -> op ids in the order close was called
    static Map<TaskName, List<String>> perTaskCloseList = new HashMap<>();
    int numInitCalled = 0;
    int numCloseCalled = 0;
    // set by init; null until this instance has been initialized
    TaskName taskName = null;
    final String opId;

    public BaseTestFunction(String opId) {
      this.opId = opId;
    }

    /** Clears all recorded instances and init/close orderings. */
    public static void reset() {
      perTaskFunctionMap.clear();
      perTaskCloseList.clear();
      perTaskInitList.clear();
    }

    /**
     * Returns the instance that was initialized for {@code opId} in the given task.
     *
     * @throws NullPointerException if no function has been initialized for {@code taskName}
     */
    public static BaseTestFunction getInstanceByTaskName(TaskName taskName, String opId) {
      return perTaskFunctionMap.get(taskName).get(opId);
    }

    /** Returns the op ids initialized for the task in call order, or null if none were recorded. */
    public static List<String> getInitListByTaskName(TaskName taskName) {
      return perTaskInitList.get(taskName);
    }

    /** Returns the op ids closed for the task in call order, or null if none were recorded. */
    public static List<String> getCloseListByTaskName(TaskName taskName) {
      return perTaskCloseList.get(taskName);
    }

    /**
     * Records that this function was closed.
     *
     * @throws IllegalStateException if this instance, or its op id within its task, was never initialized
     */
    @Override
    public void close() {
      if (this.taskName == null) {
        throw new IllegalStateException("Close called before init");
      }
      Map<String, BaseTestFunction> initializedInTask = perTaskFunctionMap.get(this.taskName);
      if (initializedInTask == null || !initializedInTask.containsKey(opId)) {
        throw new IllegalStateException("Close called before init");
      }

      perTaskCloseList.computeIfAbsent(this.taskName, task -> new ArrayList<>()).add(opId);
      this.numCloseCalled++;
    }

    /**
     * Records that this function was initialized for the task found in {@code context}.
     *
     * @throws IllegalStateException if a function with the same op id was already initialized for the task
     */
    @Override
    public void init(Context context) {
      TaskName currentTask = context.getTaskContext().getTaskModel().getTaskName();

      Map<String, BaseTestFunction> initializedInTask =
          perTaskFunctionMap.computeIfAbsent(currentTask, task -> new HashMap<>());
      if (initializedInTask.containsKey(opId)) {
        // report the task being initialized: the taskName field is still null when a different
        // instance registered this op id, and dereferencing it would hide this error behind an NPE
        throw new IllegalStateException(String.format("Multiple init called for op %s in the same task instance %s", opId, currentTask.getTaskName()));
      }
      initializedInTask.put(opId, this);

      perTaskInitList.computeIfAbsent(currentTask, task -> new ArrayList<>()).add(opId);
      this.taskName = currentTask;
      this.numInitCalled++;
    }
  }
}
