[ChatGPT] 冷蔵庫内の写真から「おすすめレシピ」を受け取ってみました 〜食品は、Segment Anything と 転移学習した分類モデルで検出してます〜

[ChatGPT] 冷蔵庫内の写真から「おすすめレシピ」を受け取ってみました 〜食品は、Segment Anything と 転移学習した分類モデルで検出してます〜

Clock Icon2023.05.03

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

1 はじめに

CX 事業本部 delivery部の平内(SIN)です。

ChatGPTを利用すると、いくつかの食品を提示して「おすすめのレシピ」を答えてもらうことができます。

そして、冷蔵庫内の写真から、それを行う方法が、The Multimodal And Modular Ai Chef: Complex Recipe Generation From Imageryで紹介されています。

https://arxiv.org/pdf/2304.02016.pdf

GPT3.5では、画像入力ができませんが、画像を解釈するための物体検出モデルを前段に組み合わせることで、面白い体験ができるものだと感心します。

今回は、これを「私もやってみたい」ということで、真似事ですが、試してみた記録です。

なお、上記の論文では、食品検出用のモデルは、オープンソースデータセットのai-cook-lcv4dを使用して、YOLOv5をファインチューニングすることで作成されていましたが、ai-cookデータセットは、手元の冷蔵庫や食品とうまくマッチしなかったので、Segment Anythingと転移学習した分類モデル(Resnet)の組み合わせで、行うことにしました。

最初に、動作している様子です。

冷蔵庫内を撮影した写真を入力すると、セグメンテーションで切り取られ、それぞれが、どの食品に該当するかを分類モデルで判定しています。そして、取得できた食品の一覧をプロンプトに入れて、gpt-3.5-turboでレシピの提案を出力しています。

2 Segment Anything

(1) SamAutomaticMaskGenerator

Segment Anything Model(SAM) では、Automatically generating objectで画面全体を自動的にセグメンテーションできます。

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.9,
    stability_score_thresh=0.96,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)
masks = mask_generator.generate(image)

SamAutomaticMaskGenerator() では、各種の引数を指定できますが、pred_iou_threshを変更することで、セグメンテーションの具合を調整できます。この値は、デフォルトで0.88ですが、色々試した結果、今回の960×720の画像では、0.98 が丁度いい感じでした。

mask_generator = SamAutomaticMaskGenerator(
    model=sam, pred_iou_thresh=0.98
)
masks = mask_generator.generate(image)

参考: https://replicate.com/pablodawson/segment-anything-automatic/api

(2) Mask

取得したマスクは、以下のような内容になっています。

* segmentation : マスク
* area : マスクのピクセル面積
* bbox : XYWH形式の境界ボックス
* predicted_iou : マスクの品質に対する予測
* point_coords : このマスクを生成したサンプリングされた入力ポイント
* stability_score : マスク品質の追加尺度
* crop_box : このマスクをXYWH形式で生成するために使用される画像のトリミング

ここで、areaが、オブジェクトのサイズを表現しているため、適当な大きさのもののみを検出対象としました。

# 一定のサイズのものだけを抽出する
max_pixels = 13000
min_pixels = 2300

for index in range(sam.length):
    mask, pixels = sam.get(index)

    # 一定範囲のサイズのみ対象にする
    if pixels < min_pixels or max_pixels < pixels:
        continue

segmentationは、2次元のboolとなっています。

shape:(960, 720) type:<class 'numpy.ndarray'> dtype:bool ndim:2

これを使用して、2値の黒画像にプロットすることで、輪郭を取得したり、元画像の対象外部分を白く塗ることで、切り取り用の画像を生成しています。(areaクラスは、x,yのそれぞれ最大値と最小値を取得しているものです。)

# 輪郭検出用の2値のテンポラリ画像
mono_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
# 個々の食品を切取るためのテンポラリ画像
food_image = image.copy()

area = Area(width, height)
for y in range(height):
    for x in range(width):
        if mask[y][x]:
            mono_image[y][x] = 0  # 2値画像は、マスク部分を黒にする
            area.set(x, y)
        else:
            food_image[y][x] = [255, 255, 255]  # 食品切取り画像は、マスク部分以外を白にする
# 検出範囲
x1, y1, x2, y2 = area.get()
# 食品の輪郭を取得する
contours, _ = cv2.findContours(
    mono_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
# 食品画像を切取る
food_image = food_image[y1:y2, x1:x2]

3 分類モデル

(1) 転移学習

分類モデルは、Resnetを転移学習して作成しました。

コードは、Pytorchの Tutorials に置かれている finetuning_torchvision_models_tutorial.ipynb をほとんどそのまま使用しています。

(2) データセット

このモデルで推論する対象は、セグメンテーションで切り取られた画像であり、背景などのノイズが無いためか、データセットとしては、学習用・検証用あわせて10枚ほどしか準備されていませんが、利用可能な精度を得ることができました。

※ 必要なデータセットの画像に関しては、食品の種類や、類似性に大きく依存するので、10枚というのは、あくまで今回の場合です

(3) 対象外の排除

セグメンテーションは、適当なサイズであれば、食品以外のものも検出されてしまいます。

そこで、そのような物体も分類モデルで判別できるようにデータセットに加えています。

# 分類モデルで食品名を推論する
class_name, probs = ic.inference(food_image)

# nodeで始まる名前は、食品では無い
if class_name.startswith("none") == False:
    data_list.append(Data(food_image, contours, class_name, probs))
    print(
        "{} {:.2f}% [{},{},{},{}]".format(
            get_name(class_name), probs, x1, y1, x2, y2
        )
    )

4 gpt-3.5-turbo

(1) prompt

gpt-3.5-turboのプロンプトは、以下のようになっています。

パラメータ food_listには、先のモデルで検出された食品の一覧「とり肉,キャベツ,ピーマン」などが入ります。

def _create_prompt(self, food_list):
    list_str = ""
    for food in food_list:
        list_str += "{},".format(food)

    order = """
あなたは高級料理店の専門シェフです。 ユーザから与えられた食材のリスト使用して1種類のレシピを教えてくれます。
レシピには、キャッチーなタイトル、調理するのに必要な時間、また、レシピと同様に、リストに成分部分を含めます
"""
    role = "role"
    content = "content"
    system = "system"
    user = "user"
    assistant = "assistant"

    prompt = []
    prompt.append({role: system, content: order})
    prompt.append({role: user, content: "食材リストは、{} です".format(list_str)})
    prompt.append({role: assistant, content: ""})
    return prompt

(2) openai.ChatCompletion.create

こちらは、通常の利用方法です。temperature は、ちょっと変化を楽しみたいという意味でう、0.7としてみました。

def recipe_recommend(self, food_list):
    prompt = self._create_prompt(food_list)
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo", messages=prompt, temperature=0.7
        )
        print(response.choices[0]["message"]["content"].strip())
    except Exception as e:
        print("Exception:", e.args)

5 最後に

今回は、冷蔵庫内の写真を元に、ChatGPTで「おすすめレシピ」を受け取ってみました。LLMは、それだけでも凄いですが、マルチモーダルな組み合わせで、さらに、その世界感が広がることを実感できました。

また、Segment Anythingが、物体検出においてもゲームチェンジャーなのでは?と感じました。

動画で動作していたソースコードは、以下です。説明が不足している部分については、こちらを参照頂ければ幸いです。

index.py
import os
import math
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from torch.nn import functional as F
from torchvision import transforms
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import japanize_matplotlib
import openai


# 分類モデル(resnet)
class ImageClassification:
    def __init__(self, device, model, label_file):
        print("init Image Classification")
        self.device = device

        self.model = torch.load(model)
        self.model.eval()
        self.model.to(self.device)

        with open(label_file, mode="r") as f:
            self.class_names = f.read().split(",")
        self.class_names.remove("")
        print(self.class_names)

    def inference(self, image):
        preprocess = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(256),  # サイズ変更 (256, 256)
                transforms.CenterCrop(224),  # 中心で切り抜く (224, 224)
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        input_tensor = preprocess(image)
        input_batch = input_tensor.unsqueeze(0)
        input_batch = input_batch.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_batch)

        batch_probs = F.softmax(outputs, dim=1)
        batch_probs, batch_indices = batch_probs.sort(dim=1, descending=True)

        for probs, indices in zip(batch_probs, batch_indices):
            for k in range(1):
                return self.class_names[indices[k]], probs[k]


# SAM
class SegmentAnything:
    def __init__(self, device, model_type, sam_checkpoint):
        print("init Segment Anything")
        self.device = device
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(self.device)

    def generat_masks(self, image):
        print("generate masks by Segment Anything")
        # 検出するマスクの品質 (default: 0.88)
        # pred_iou_thresh = 0.96 640*480の時ちょうどよかった
        pred_iou_thresh = 0.98  #  960×720の時はこれぐらい
        mask_generator_2 = SamAutomaticMaskGenerator(
            model=self.sam, pred_iou_thresh=pred_iou_thresh
        )
        masks = mask_generator_2.generate(image)
        self.sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
        self._length = len(masks)
        print("{} masks generated.".format(self._length))

    @property
    def length(self):
        return self._length

    def get(self, index):
        pixels = self.sorted_anns[index]["area"]
        mask = np.expand_dims(self.sorted_anns[index]["segmentation"], axis=-1)
        return mask, pixels


def read_image(file_name):
    image = cv2.imread(file_name)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width = image.shape[:2]
    return image, height, width


# GPT
class GPT3:
    def __init__(self):
        openai.api_key = os.environ["OPENAI_API_KEY"]
        print("init gpt-3.5-turbo")

    def recipe_recommend(self, food_list):
        prompt = self._create_prompt(food_list)
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo", messages=prompt, temperature=0.7
            )
            print(response.choices[0]["message"]["content"].strip())
        except Exception as e:
            print("Exception:", e.args)

    def _create_prompt(self, food_list):
        list_str = ""
        for food in food_list:
            list_str += "{},".format(food)

        order = """
あなたは高級料理店の専門シェフです。 ユーザから与えられた食材のリスト使用して1種類のレシピを教えてくれます。
レシピには、キャッチーなタイトル、調理するのに必要な時間、また、レシピと同様に、リストに成分部分を含めます
"""
        role = "role"
        content = "content"
        system = "system"
        user = "user"
        assistant = "assistant"

        prompt = []
        prompt.append({role: system, content: order})
        prompt.append({role: user, content: "食材リストは、{} です".format(list_str)})
        prompt.append({role: assistant, content: ""})
        return prompt


# エリアを検出するクラス
class Area:
    def __init__(self, width, height):
        self.x1 = width
        self.x2 = 0
        self.y1 = height
        self.y2 = 0

    def set(self, x, y):
        if x < self.x1:
            self.x1 = x
        if self.x2 < x:
            self.x2 = x
        if y < self.y1:
            self.y1 = y
        if self.y2 < y:
            self.y2 = y

    def get(self):
        return self.x1, self.y1, self.x2, self.y2


# 検出したデータを表現するクラス
class Data:
    def __init__(self, image, contours, class_name, probs):
        self._image = image
        self._contours = contours
        self._class_name = class_name
        self._probs = probs

    @property
    def image(self):
        return self._image

    @property
    def contours(self):
        return self._contours

    @property
    def class_name(self):
        return self._class_name

    @property
    def probs(self):
        return self._probs


# 名前の索引
def get_name(class_name):
    if class_name == "sausage":
        return "ソーセージ"
    if class_name == "cabbage":
        return "キャベツ"
    if class_name == "potato":
        return "じゃがいも"
    if class_name == "milk":
        return "牛乳"
    if class_name == "pumpkin":
        return "かぼちゃ"
    if class_name == "green_pepper":
        return "ピーマン"
    if class_name == "nasubi":
        return "なすび"
    if class_name == "paprika":
        return "パプリカ"
    if class_name == "carrot":
        return "人参"
    if class_name == "apple":
        return "りんご"
    if class_name == "chikin":
        return "とり肉"
    if class_name == "pork":
        return "豚肉"
    if class_name == "tomato":
        return "トマト"
    if class_name == "banana":
        return "バナナ"
    if class_name == "onion":
        return "玉ねぎ"
    return "食品ではない".format(class_name)


def job(filename, sam, ic, gpt3):
    print("==========================================================")
    print("おすすめレシピ")
    print("Configured in gpt-3.5-turbo, Segment Anything and resnet. ")
    print("==========================================================")
    # 画像の読み込み
    image, height, width = read_image(filename)
    print("read_image {} width:{} height:{}".format(filename, width, height))
    print("この写真を元に、おすすめのレシピを提案します。")
    plt.imshow(image)
    plt.axis("off")
    plt.show()

    # マスク生成
    sam.generat_masks(image)

    data_list = []
    # 一定のサイズのものだけを抽出する
    max_pixels = 13000
    min_pixels = 2300

    for index in range(sam.length):
        mask, pixels = sam.get(index)

        # 一定範囲のサイズのみ対象にする
        if pixels < min_pixels or max_pixels < pixels:
            continue

        # 輪郭検出用の2値のテンポラリ画像
        mono_image = np.full(np.array([height, width, 1]), 255, dtype=np.uint8)
        # 個々の食品を切取るためのテンポラリ画像
        food_image = image.copy()

        area = Area(width, height)
        for y in range(height):
            for x in range(width):
                if mask[y][x]:
                    mono_image[y][x] = 0  # 2値画像は、マスク部分を黒にする
                    area.set(x, y)
                else:
                    food_image[y][x] = [255, 255, 255]  # 食品切取り画像は、マスク部分以外を白にする

        # 検出範囲
        x1, y1, x2, y2 = area.get()
        # 食品の輪郭を取得する
        contours, _ = cv2.findContours(
            mono_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # 食品画像を切取る
        food_image = food_image[y1:y2, x1:x2]
        # 分類モデルで食品名を推論する
        class_name, probs = ic.inference(food_image)

        # nodeで始まる名前は、食品では無い
        if class_name.startswith("none") == False:
            data_list.append(Data(food_image, contours, class_name, probs))
            print(
                "{} {:.2f}% [{},{},{},{}]".format(
                    get_name(class_name), probs, x1, y1, x2, y2
                )
            )

    # おすすめレシピ
    print("----------------------------------------------------------")
    food_list = []
    for data in data_list:
        food_list.append(get_name(data.class_name))

    gpt3.recipe_recommend(food_list)
    print("----------------------------------------------------------")

    # 元画像に輪郭を描画
    for data in data_list:
        cv2.drawContours(image, data.contours, -1, color=[0, 255, 255], thickness=3)
    plt.imshow(image)

    # 検出した食品画像を表示
    W = math.ceil(len(data_list) / 2)
    H = 2
    fig = plt.figure(figsize=(5, 5))

    for i, data in enumerate(data_list):
        ax1 = fig.add_subplot(H, W, i + 1)
        ax1.set_title(
            "{} {:.2f}".format(get_name(data.class_name), data.probs), fontsize=10
        )
        plt.imshow(data.image)

    plt.axis("off")
    plt.show()

    # 画面クリア
    os.system("clear")


def main():
    print("PyTorch version:", torch.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    device = "cuda"

    # Resnet
    ic = ImageClassification(device, "model.pth", "label.txt")
    # Segment Anything
    sam = SegmentAnything(device, "vit_h", "sam_vit_h_4b8939.pth")
    # GPT3
    gpt3 = GPT3()

    file_list = [
        "image001.jpg",
        "image002.jpg",
        "image003.jpg",
        "image004.jpg",
        "image005.jpg",
        "image006.jpg",
    ]
    for file in file_list:
        job(file, sam, ic, gpt3)


if __name__ == "__main__":
    main()

この記事をシェアする

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.