Language Analysis with Machine Learning 2


In a previous post, I summarized decision trees.
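As a brief reminder of how such a tree chooses its splits: scikit-learn's tree classifiers use Gini impurity by default, and it is the `gini` value printed in each node of the tree plotted later. The minimal pure-Python sketch below computes it. The helper names `gini` and `split_gain` are invented for illustration. The 5/5/5 label counts mirror the class balance of the sample sentences used in this post, and the split is hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum over classes k of p_k^2 (0.0 for an empty node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def split_gain(parent, left, right):
    """Impurity decrease: parent impurity minus the size-weighted impurity of the children."""
    n = len(parent)
    weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(parent) - weighted


# Root node with 5 samples per class, as in the 3-class sample data
parent = ["sports"] * 5 + ["tech"] * 5 + ["food"] * 5
print(round(gini(parent), 4))  # 0.6667

# Hypothetical split that isolates every "sports" sentence on the left
left = ["sports"] * 5
right = ["tech"] * 5 + ["food"] * 5
print(round(split_gain(parent, left, right), 4))  # 0.3333
```

At each node, the tree tries candidate thresholds on each feature and keeps the split with the largest impurity decrease.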

This time, I'll try implementing one, using text classification as the theme. The code is below.

# --- 1. Libraries ---
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# --- 2. Sample sentences (15 sentences, 3 classes) ---
sentences = [
    # sports
    "The striker scored a stunning goal in the final minute.",
    "Our team trained hard for the marathon this autumn.",
    "She broke the national record in the 100-metre sprint.",
    "The coach praised the goalkeeper's quick reflexes.",
    "Ticket sales soared after the club won the championship.",
    # technology
    "Quantum computing promises to revolutionise encryption.",
    "The new smartphone features an innovative folding screen.",
    "Researchers released an open-source large language model.",
    "Cloud latency was reduced by deploying edge servers.",
    "The startup secured funding for its AI-driven platform.",
    # food
    "The bakery's sourdough has a perfectly crisp crust.",
    "He paired the matured cheddar with a bold red wine.",
    "Street vendors served aromatic pho on every corner.",
    "Seasonal truffles elevated the pasta to fine-dining status.",
    "Cold-brew coffee tasted smooth with chocolate notes."
]
labels = [
    "sports","sports","sports","sports","sports",
    "tech","tech","tech","tech","tech",
    "food","food","food","food","food"
]

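# --- 3. Pipeline (TF-IDF → Random Forest) ---
#    * To keep the tree structure readable, the forest is limited to a single tree (max_depth=4)
#    * With the defaults bootstrap=True and max_features="sqrt", that tree is trained on a
#      bootstrap resample of the 15 sentences and considers a random subset of features at each split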
model = make_pipeline(
    TfidfVectorizer(ngram_range=(1,2), stop_words="english"),
    RandomForestClassifier(
        n_estimators=1,     # build only one tree
        max_depth=4,        # keep the tree shallow
        random_state=42
    )
)

# --- 4. Training ---
model.fit(sentences, labels)

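# --- 5. Visualizing the decision tree ---
# Retrieve the forest and its single decision tree from the trained pipeline
forest = model.named_steps["randomforestclassifier"]
tree = forest.estimators_[0]
vectorizer = model.named_steps["tfidfvectorizer"]

# Take the class names from the forest, not from the tree: the forest fits its trees
# on integer-encoded labels, so tree.classes_ is [0., 1., 2.] rather than the label strings
class_names_str = [str(c) for c in forest.classes_]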

plt.figure(figsize=(24,10))
plot_tree(
    tree,
    feature_names=vectorizer.get_feature_names_out(),
    class_names=class_names_str,
    filled=True,
    rounded=True,
    fontsize=8
)
plt.title("Decision Tree Structure (from RandomForest, n_estimators=1)")
plt.show()
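The trees inside a `RandomForestClassifier` are fitted on integer-encoded labels. An individual tree's `classes_` therefore holds `0.0, 1.0, 2.0`, while the label strings are stored on the forest. The sketch below checks this and compares the single-tree forest with scikit-learn's `DecisionTreeClassifier`, which fits one tree on all samples without bootstrapping and keeps the label strings. The six-sentence subset and the names `docs`, `y`, `forest_pipe` and `tree_pipe` are hypothetical choices made to keep the example short.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# A six-sentence subset of the sample data (chosen only to keep this sketch short)
docs = [
    "The striker scored a stunning goal in the final minute.",
    "The coach praised the goalkeeper's quick reflexes.",
    "Quantum computing promises to revolutionise encryption.",
    "Cloud latency was reduced by deploying edge servers.",
    "The bakery's sourdough has a perfectly crisp crust.",
    "Cold-brew coffee tasted smooth with chocolate notes.",
]
y = ["sports", "sports", "tech", "tech", "food", "food"]

# Single-tree forest, configured like the model above
forest_pipe = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2), stop_words="english"),
    RandomForestClassifier(n_estimators=1, max_depth=4, random_state=42),
).fit(docs, y)
forest = forest_pipe.named_steps["randomforestclassifier"]

# The forest keeps the label strings; its internal tree only sees encoded labels
print(forest.classes_)                 # ['food' 'sports' 'tech']
print(forest.estimators_[0].classes_)  # [0. 1. 2.]

# A standalone decision tree: no bootstrap resample, all features considered at every split
tree_pipe = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2), stop_words="english"),
    DecisionTreeClassifier(max_depth=4, random_state=42),
).fit(docs, y)
print(tree_pipe.named_steps["decisiontreeclassifier"].classes_)  # ['food' 'sports' 'tech']
```

Which variant to plot depends on the goal. If the aim is to inspect one tree of a forest, keep the forest. If the aim is a single readable tree that sees all the data, `DecisionTreeClassifier` is the more direct choice.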
