【MediaPipe】HelloWorldのプログラム動作/処理を解析してみた

2020.05.26

カフェチームの山本です。

前回はHelloWorldのプログラムを参考に、MediaPipeのフレームワークとしての動作/構成を学びました。

今回は、HelloWorldのプログラムの内部を詳細に見ながら、グラフに対してデータを入出力する方法 ソースコードがどのように実行されているかを学びます。

今回も学習しただけであるため、特に新しい知見や結論はありませんが、ご参考になれば幸いです。

MediaPipeに関連する記事はこちらにまとめてあります。

【MediaPipe】投稿記事まとめ

HelloWorldの入出力

前回の記事の「HelloWorldのプログラム」で記載したように、HelloWorldのプログラムは、mediapipe/examples/desktop/hello_world/hello_world.cc から実行されています。中を見てみると、入出力に関して以下のことがわかります。

  • 出力準備:初期化したGraphに対して、出力を受け取る用の OutputStreamPoller をくっつける。この際、Graph中の出力ストリームである "out" にアタッチする。(43~44行目)
  • 入力:処理を開始したGraphに対して、入力ストリームである "in" を指定して、データを入力する。この際、データは Packet という型にラップし、タイムスタンプを付与する。(48~49行目)
  • 出力:ポーラを介して、グラフ中のデータがなくなるまで出力を受け取る。この際、データの型は Packet で受け取るため、データを変換する。(55~57行目)

mediapipe/examples/desktop/hello_world/hello_world.cc

#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

::mediapipe::Status PrintHelloWorld() {
  // Configures a simple graph, which concatenates 2 PassThroughCalculators.
  CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
    input_stream: "in"
    output_stream: "out"
    node {
      calculator: "PassThroughCalculator"
      input_stream: "in"
      output_stream: "out1"
    }
    node {
      calculator: "PassThroughCalculator"
      input_stream: "out1"
      output_stream: "out"
    }
  )");

  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("out"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  // Give 10 input packets that contains the same std::string "Hello World!".
  for (int i = 0; i < 10; ++i) {
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
        "in", MakePacket<std::string>("Hello World!").At(Timestamp(i))));
  }
  // Close the input stream "in".
  MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
  mediapipe::Packet packet;
  // Get the output packets std::string.
  while (poller.Next(&packet)) {
    LOG(INFO) << packet.Get<std::string>();
  }
  return graph.WaitUntilDone();
}
}  // namespace mediapipe

Graph内のプログラム

Graph内のPassThroughCalculatorですが、mediapipe/examples/desktop/hello_world/BUILDを見ると、mediapipe/calculators/core内のpass_through_calculatorが実行されている(いそうな)ことがわかります。

mediapipe/examples/desktop/hello_world/BUILD

cc_binary(
    name = "hello_world",
    srcs = ["hello_world.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/core:pass_through_calculator",
        "//mediapipe/framework:calculator_graph",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)

mediapipe/calculators/core/BUILDを見ると、pass_through_calculatorは同じファイルのpass_through_calculator.ccを参照していることがわかります。

mediapipe/calculators/core/BUILD

cc_library(
    name = "pass_through_calculator",
    srcs = ["pass_through_calculator.cc"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

mediapipe/calculators/core/pass_through_calculator.ccを見ると、以下のことがわかります。

  • PassThroughCalculatorが定義されており、REGISTER_CALCULATORで登録され、hello_world.ccのGraphの定義で使用されている。(29, 96行目)
  • PassThroughCalculatorはCalculatorBaseを継承しており、静的関数GetCongractと、Open・Processがオーバライドされている。Closeは定義されていない。(31, 62, 79行目)
  • ProcessはCalculatorContextという型の入力を受け取り、cc->Inputs().Get(id).Value()のようにして入力を受け取り、cc->Outputs().Get(id).AddPacket()のようにして出力する(90行目)
    • この id には 0が入ります("in" や "out1" のような入力/出力の数に対応します)。cc->Outputs().Get(id).Name()で"in"や"out1"が得られます。

mediapipe/calculators/core/pass_through_calculator.cc

class PassThroughCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) {
      return ::mediapipe::InvalidArgumentError(
          "Input and output streams to PassThroughCalculator must use "
          "matching tags and indexes.");
    }
    for (CollectionItemId id = cc->Inputs().BeginId();
         id < cc->Inputs().EndId(); ++id) {
      cc->Inputs().Get(id).SetAny();
      cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Get(id));
    }
    for (CollectionItemId id = cc->InputSidePackets().BeginId();
         id < cc->InputSidePackets().EndId(); ++id) {
      cc->InputSidePackets().Get(id).SetAny();
    }
    if (cc->OutputSidePackets().NumEntries() != 0) {
      if (!cc->InputSidePackets().TagMap()->SameAs(
              *cc->OutputSidePackets().TagMap())) {
        return ::mediapipe::InvalidArgumentError(
            "Input and output side packets to PassThroughCalculator must use "
            "matching tags and indexes.");
      }
      for (CollectionItemId id = cc->InputSidePackets().BeginId();
           id < cc->InputSidePackets().EndId(); ++id) {
        cc->OutputSidePackets().Get(id).SetSameAs(
            &cc->InputSidePackets().Get(id));
      }
    }
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    for (CollectionItemId id = cc->Inputs().BeginId();
         id < cc->Inputs().EndId(); ++id) {
      if (!cc->Inputs().Get(id).Header().IsEmpty()) {
        cc->Outputs().Get(id).SetHeader(cc->Inputs().Get(id).Header());
      }
    }
    if (cc->OutputSidePackets().NumEntries() != 0) {
      for (CollectionItemId id = cc->InputSidePackets().BeginId();
           id < cc->InputSidePackets().EndId(); ++id) {
        cc->OutputSidePackets().Get(id).Set(cc->InputSidePackets().Get(id));
      }
    }
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    cc->GetCounter("PassThrough")->Increment();
    if (cc->Inputs().NumEntries() == 0) {
      return tool::StatusStop();
    }
    for (CollectionItemId id = cc->Inputs().BeginId();
         id < cc->Inputs().EndId(); ++id) {
      if (!cc->Inputs().Get(id).IsEmpty()) {
        VLOG(3) << "Passing " << cc->Inputs().Get(id).Name() << " to "
                << cc->Outputs().Get(id).Name() << " at "
                << cc->InputTimestamp().DebugString();
        cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value());
      }
    }
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughCalculator);

まとめ

今回は、HelloWorldのプログラムがどのように動いているかを調べるため、ソースコードを追って見てみました。

次回は、Multi Hand Trackingのプログラムを見ていきます。