機械学習による言語分析 2
以前の投稿で決定木についてまとめた.
今回は, テキスト分類を題材にその実装を行なってみたい. コードは下記.
# --- 1. ライブラリ ---
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# --- 2. Sample sentences (15 sentences, 3 classes) ---
# Five example sentences per class, keyed by the class label. Keeping text
# and label together guarantees the two flat lists below stay aligned.
samples_by_label = {
    "sports": [
        "The striker scored a stunning goal in the final minute.",
        "Our team trained hard for the marathon this autumn.",
        "She broke the national record in the 100-metre sprint.",
        "The coach praised the goalkeeper's quick reflexes.",
        "Ticket sales soared after the club won the championship.",
    ],
    "tech": [
        "Quantum computing promises to revolutionise encryption.",
        "The new smartphone features an innovative folding screen.",
        "Researchers released an open-source large language model.",
        "Cloud latency was reduced by deploying edge servers.",
        "The startup secured funding for its AI-driven platform.",
    ],
    "food": [
        "The bakery's sourdough has a perfectly crisp crust.",
        "He paired the matured cheddar with a bold red wine.",
        "Street vendors served aromatic pho on every corner.",
        "Seasonal truffles elevated the pasta to fine-dining status.",
        "Cold-brew coffee tasted smooth with chocolate notes.",
    ],
}
# Flatten in insertion order: all sports sentences, then tech, then food.
sentences = [text for texts in samples_by_label.values() for text in texts]
# One label per sentence, in the same order as `sentences`.
labels = [label for label, texts in samples_by_label.items() for _ in texts]
# --- 3. Pipeline (TF-IDF -> Random Forest) ---
# Text is turned into TF-IDF weights over unigrams and bigrams, with the
# built-in English stop-word list removed.
tfidf_step = TfidfVectorizer(ngram_range=(1, 2), stop_words="english")
# * To keep the tree structure easy to read, the forest is restricted to a
#   single tree of depth at most 4; random_state pins the result.
forest_step = RandomForestClassifier(
    n_estimators=1,
    max_depth=4,
    random_state=42,
)
# --- 4. Training ---
# make_pipeline names each step after its lower-cased class name
# ("tfidfvectorizer", "randomforestclassifier"); fit() returns the pipeline.
model = make_pipeline(tfidf_step, forest_step).fit(sentences, labels)
# --- 5. Visualise the decision tree ---
# Pull the fitted steps out of the pipeline: the forest, its single tree, and
# the vectorizer whose vocabulary supplies the feature names.
forest = model.named_steps["randomforestclassifier"]
tree = forest.estimators_[0]
vectorizer = model.named_steps["tfidfvectorizer"]
# Take the class names from the forest, not from the tree. The forest encodes
# the labels as indices before fitting its sub-trees, so tree.classes_ holds
# 0.0, 1.0, 2.0 while forest.classes_ holds the original label strings in the
# same index order. Reading tree.classes_ would label the nodes "0.0", "1.0",
# "2.0" instead of "food", "sports", "tech".
class_names_str = [str(c) for c in forest.classes_]
# Pass a plain list of names: some scikit-learn releases validate this
# argument as a list and reject the NumPy array returned by the vectorizer.
feature_names = list(vectorizer.get_feature_names_out())
plt.figure(figsize=(24, 10))
plot_tree(
    tree,
    feature_names=feature_names,
    class_names=class_names_str,
    filled=True,
    rounded=True,
    fontsize=8,
)
plt.title("Decision Tree Structure (from RandomForest, n_estimators=1)")
plt.show()
