871 lines
36 KiB
Python
871 lines
36 KiB
Python
"""
|
||
古诗词意境图生成器
|
||
将中国古典诗词通过 LLM 分析拆解为多个意境画面,
|
||
再使用 Z-Image-Turbo 本地模型逐一生成高质量图片。
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import yaml
|
||
from openai import OpenAI
|
||
from PIL import Image
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 设备检测与适配
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _init_xpu():
|
||
"""尝试初始化 Intel XPU 支持(需要 intel-extension-for-pytorch)。"""
|
||
try:
|
||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||
return True
|
||
except ImportError:
|
||
return False
|
||
|
||
|
||
def resolve_device(configured_device: str) -> str:
|
||
"""根据配置和硬件可用性,决定实际使用的推理设备。
|
||
|
||
优先级: 用户配置 > auto 自动检测
|
||
auto 检测顺序: cuda > xpu > mps > cpu
|
||
"""
|
||
if configured_device == "auto":
|
||
if torch.cuda.is_available():
|
||
return "cuda"
|
||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||
_init_xpu()
|
||
return "xpu"
|
||
if hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||
return "mps"
|
||
return "cpu"
|
||
|
||
if configured_device == "xpu":
|
||
if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
|
||
print("警告: 配置了 xpu 设备但未检测到 Intel XPU,尝试初始化 IPEX...")
|
||
if not _init_xpu():
|
||
print("错误: 无法加载 intel-extension-for-pytorch,请确认已安装。")
|
||
print("安装命令: pip install intel-extension-for-pytorch")
|
||
sys.exit(1)
|
||
if not torch.xpu.is_available():
|
||
print("错误: IPEX 已加载但仍未检测到 XPU 设备。")
|
||
sys.exit(1)
|
||
else:
|
||
_init_xpu()
|
||
|
||
return configured_device
|
||
|
||
|
||
def get_supported_dtype(device: str, configured_dtype: str) -> torch.dtype:
|
||
"""根据设备返回合适的数据类型。
|
||
|
||
Intel Arc GPU 对 bfloat16 部分算子兼容性不佳,推荐使用 float16。
|
||
"""
|
||
dtype_map = {
|
||
"bfloat16": torch.bfloat16,
|
||
"float16": torch.float16,
|
||
"float32": torch.float32,
|
||
}
|
||
|
||
if configured_dtype == "auto":
|
||
if device == "xpu":
|
||
return torch.float16
|
||
if device in ("cuda", "mps"):
|
||
return torch.bfloat16
|
||
return torch.float32
|
||
|
||
dtype = dtype_map.get(configured_dtype, torch.bfloat16)
|
||
|
||
if device == "xpu" and dtype == torch.bfloat16:
|
||
print("提示: Intel Arc GPU 上 bfloat16 部分算子兼容性不佳,自动切换为 float16")
|
||
return torch.float16
|
||
|
||
return dtype
|
||
|
||
|
||
def create_generator(device: str, seed: int) -> torch.Generator:
|
||
"""为指定设备创建随机数生成器。"""
|
||
if device == "xpu":
|
||
return torch.Generator("xpu").manual_seed(seed)
|
||
if device == "cuda":
|
||
return torch.Generator("cuda").manual_seed(seed)
|
||
return torch.Generator().manual_seed(seed)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 配置加载
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def load_config(config_path: str = "config.yaml") -> dict:
|
||
with open(config_path, "r", encoding="utf-8") as f:
|
||
cfg = yaml.safe_load(f)
|
||
|
||
api_key = os.environ.get("LLM_API_KEY") or cfg["llm"].get("api_key", "")
|
||
cfg["llm"]["api_key"] = api_key
|
||
return cfg
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM 古诗词分析
|
||
# ---------------------------------------------------------------------------
|
||
|
||
SYSTEM_PROMPT = """\
|
||
# Role
|
||
|
||
你是一位顶级的中国古典文学泰斗,同时也是精通 AI 图像生成(Text-to-Image,特别是 Z-Image-Turbo/Midjourney 等扩散模型)底层逻辑的顶级提示词工程师。
|
||
你对古诗词中的"意境"、"留白"、"虚实"有深刻理解,并能将这些抽象概念精准转化为扩散模型能够识别的**高权重视觉参数**(如明确的光影走向、材质纹理、构图视角、笔触细节)。
|
||
|
||
# Objective
|
||
|
||
接收用户输入的古诗词,严格按照"四段式思维链"转化为最高质量的图像生成提示词。你需要能够探索全诗连贯意象,将诗句转化为1张或多张分镜,并为每张分镜提供2种不同且极其细致的传统绘画风格提示词,确保画面不仅贴合诗意,而且具备极高的艺术审美与画面张力。
|
||
|
||
# Workflow(强制执行四段式思维链)
|
||
|
||
对于每一次输入,必须在内部严格执行以下步骤,结果最终输出为 JSON:
|
||
|
||
## 第一步:意境与分镜逻辑判断
|
||
- 若全诗时空统一,生成【单幅画面】。
|
||
- 若存在明显的视角切换(远近/高低)、时间推移(朝暮)或场景跳跃,拆分为 2-4 幅【分镜序列】。相邻且意境连贯的诗句应合并。
|
||
|
||
## 第二步:意境深度解析
|
||
分析每个分镜:核心情感基调、季节时间、天气状态、意境类型与情感张力。
|
||
|
||
## 第三步:现代文视觉脚本扩写(核心视觉转义)
|
||
将分镜扩写为极具画面感的视觉脚本,**必须将抽象词汇翻译为肉眼可见的物理细节**:
|
||
- **主体与动作**:人物姿态/服饰/微表情,核心景物的精确形态。
|
||
- **配景与层次**:前景、中景、远景的具体构成,建立空间纵深。
|
||
- **光线与色彩**:必须明确光源(如斜侧逆光、清冷月光、丁达尔效应/体积光)、色调(冷暖对比、低饱和度等)。
|
||
- **气候与动态**:风的方向、云雾的形态(流云/贴水薄雾)、水波的纹理(波光粼粼/惊涛骇浪)。
|
||
- **构图与尺度**:必须写明镜头视角(超大远景 / 黄金分割构图 / 仰视等),大远景必须加入尺度参照(远帆、飞鸟剪影、孤亭)以体现宏大感。
|
||
|
||
## 第四步:图像 Prompt 生成(双画风)
|
||
基于第三步的视觉脚本,为每一分镜设计 **2套** 彼此不同类别的中国传统画风(如:水墨写意 vs 青绿山水;工笔重彩 vs 浅绛山水)。
|
||
|
||
### 🎨 提示词(Prompt)构建法则(极其重要)
|
||
1. **中英文对应与结构**:`prompt` 和 `prompt_en` 必须是一段连续、自然流畅的描述(**绝对禁止出现[ ] 括号或标签名**)。
|
||
2. **英文生图语法强化 (`prompt_en`)**:英文提示词对模型影响最大,结构必须为:
|
||
`[画风约束词] +[画面主体与动作] + [环境与空间层次] + [光影与气候细节] +[笔触/色彩/媒介质感] + [顶级画质词]`。
|
||
3. **拒绝空洞抽象**:不要只写"sorrowful atmosphere"或"philosophical depth";必须用"withered lotus stalks bending in the cold wind, subdued blue-gray color palette"来表现抽象感。
|
||
4. **高质量风格约束词表(必须从以下挑选或组合并在 prompt 结尾处体现)**:
|
||
- **水墨写意**:Traditional Chinese ink wash painting, freehand brushwork (Xieyi), negative space, ethereal mist, varied ink tones, rhythmic brush strokes.
|
||
- **青绿山水**:Traditional Chinese blue-green landscape painting, mineral pigments, azurite and malachite tones, gold foil accents, majestic momentum.
|
||
- **工笔重彩**:Chinese meticulous heavy-color painting (Gongbi), rich saturated pigments, elaborate fine line drawing, opulent details, highly decorative.
|
||
- **浅绛山水**:Light crimson landscape painting, ochre wash, sparse and distant, elegant and refined, minimalist composition.
|
||
|
||
### 长度与质量标准
|
||
- `prompt`(中文):150 - 250 字。
|
||
- `prompt_en`(英文):100 - 200 词,多使用形容词+名词的词组(如 `volumetric lighting`, `cinematic lighting`, `intricate details`, `masterpiece, 8k resolution, best quality`)。
|
||
|
||
# Output Format (JSON Only)
|
||
|
||
严格输出以下 JSON 结构,不要包含任何多余解释:
|
||
|
||
```json
|
||
{
|
||
"title": "诗词标题",
|
||
"author": "作者",
|
||
"dynasty": "朝代",
|
||
"genre": "体裁",
|
||
"analysis": "时空统一,全诗描绘大漠孤烟的壮阔黄昏,故作为单幅画面。意境雄浑苍凉,核心情感是孤寂与壮美的交织。画面需重点表现极致的几何对比(直烟与圆日)和宏大的空间尺度。",
|
||
"images":[
|
||
{
|
||
"scene": "大漠孤烟直,长河落日圆。",
|
||
"description": "超大远景构图。前景是连绵起伏的金黄色沙丘,沙纹在夕阳斜照下呈现出明暗交界的锋利边缘。中景一条宽阔的河流蜿蜒折射着波光。视线中央,一道笔直的白色烽烟冲天而起,没有一丝风。远景的地平线上,一轮巨大、血红的落日正悬挂在长河尽头。冷暖色调形成强烈对比,画面极具几何雄浑之美,远空有几只渺小的飞鸟剪影作为尺度参照。",
|
||
"variants":[
|
||
{
|
||
"style": "浅绛山水",
|
||
"style_rationale": "浅绛山水的赭石淡彩能完美表现大漠黄昏的苍茫与孤寂感,线条萧疏清远。",
|
||
"prompt": "一幅传统的中国浅绛山水画,超大远景构图。画面中央是一片连绵起伏的沙丘,沙纹细腻,远方一条宽阔的长河蜿蜒流淌。长河尽头的地平线上悬挂着一轮巨大的血红色落日。一道笔直的白色烽烟从烽火台冲天而起,直入云霄。天空中点缀着几只微小的飞鸟剪影,凸显出大漠的浩瀚无垠。画面采用赭石淡彩着色,夕阳的余晖给沙丘和河面染上一层凄美的暖光。留白与虚实相生,意境苍凉雄浑。杰作,8k分辨率,极致细节,电影级光影,最高画质。",
|
||
"prompt_en": "A masterpiece of traditional Chinese light crimson landscape painting (Qianjiang), ultra-wide panoramic shot. Endless rolling sand dunes with delicate ripples in the foreground. A wide, majestic river winds its way through the vast desert. At the distant horizon of the river, a giant, blood-red setting sun hangs low. A single, perfectly straight column of white smoke rises directly into the sky from an ancient beacon tower. Tiny silhouettes of flying birds in the vast sky provide a sense of grand scale. Colored with subtle ochre wash and pale warm tones. The golden hour lighting casts long dramatic shadows on the dunes. Ethereal atmosphere, negative space, sparse and distant, traditional Chinese brushwork, masterpiece, 8k resolution, highly detailed, cinematic lighting, breathtaking scenery."
|
||
},
|
||
{
|
||
"style": "泼墨大写意",
|
||
"style_rationale": "通过墨色的酣畅淋漓与狂放笔触,强化沙漠与落日之间的磅礴气势与浑厚张力。",
|
||
"prompt": "一幅气势磅礴的中国泼墨大写意画。用浓淡相宜的泼墨挥洒出连绵不绝的苍茫大漠与雄浑山势,笔触狂放且充满力量。一条留白形成的长河贯穿画面,河面波光隐约。长河尽头,用朱砂重彩点染出一轮巨大而耀眼的落日,与周围的黑白墨色形成极具视觉冲击力的红黑对比。一道用枯笔飞白表现的笔直烽烟直刺苍穹。画面充满墨色淋漓的律动感,光影粗犷,意境苍茫悲壮。杰作,8k画质,令人惊叹的笔触细节,艺术珍品。",
|
||
"prompt_en": "A majestic traditional Chinese splash ink painting (Da Xieyi), majestic momentum. Bold, expressive, and sweeping ink brushstrokes create the vast, endless desert landscape and rugged terrain. A wide river is formed by masterful use of negative space, flowing through the center. At the end of the river, a massive, vibrant vermilion red setting sun is painted with heavy pigments, creating a striking contrast against the monochromatic black and gray ink wash. A straight column of smoke rises to the sky, rendered with dry brush techniques (Feibai). Dynamic ink splashes, rhythmic brushstrokes, bold black-and-red color contrast, atmospheric and dramatic lighting, masterpiece, 8k resolution, highly detailed, traditional Chinese art museum quality."
|
||
}
|
||
]
|
||
}
|
||
]
|
||
}\
|
||
"""
|
||
|
||
|
||
def _build_user_message(poem: str, cfg: dict) -> str:
|
||
"""构造发送给 LLM 的用户消息,包含诗词和可选的风格期望。"""
|
||
style_pref = cfg["image"].get("style_preference", "").strip()
|
||
if style_pref:
|
||
style_line = f"【风格期望】:{style_pref}"
|
||
else:
|
||
style_line = "【风格期望】:默认(根据诗意自动选择最契合的传统画风)"
|
||
|
||
return (
|
||
f"请为以下古诗词生成图像提示词:\n\n"
|
||
f"【输入诗词】:\n{poem}\n\n"
|
||
f"{style_line}\n\n"
|
||
f"请严格按照 System Prompt 的要求,首先进行【意境与分镜逻辑判断】,"
|
||
f"随后针对单幅或多幅分镜依次输出对应的【意境深度解析】、"
|
||
f"【现代文视觉转义】,并为**每一幅分镜**输出 **2 套**不同传统画风的【图像生成 Prompt】(`variants` 数组)。"
|
||
)
|
||
|
||
|
||
def analyze_poetry(poem: str, cfg: dict) -> dict:
|
||
"""调用 LLM 分析古诗词,返回结构化的图片生成方案。"""
|
||
llm_cfg = cfg["llm"]
|
||
|
||
client = OpenAI(
|
||
base_url=llm_cfg["base_url"],
|
||
api_key=llm_cfg["api_key"],
|
||
timeout=120,
|
||
)
|
||
|
||
style_pref = cfg["image"].get("style_preference", "").strip()
|
||
print(f"\n{'='*60}")
|
||
print("正在调用 LLM 分析古诗词意境(四段式思维链)...")
|
||
print(f"模型: {llm_cfg['model']}")
|
||
if style_pref:
|
||
print(f"风格期望: {style_pref}")
|
||
print(f"{'='*60}\n")
|
||
|
||
user_message = _build_user_message(poem, cfg)
|
||
|
||
response = client.chat.completions.create(
|
||
model=llm_cfg["model"],
|
||
temperature=llm_cfg.get("temperature", 0.7),
|
||
max_tokens=llm_cfg.get("max_tokens", 4096),
|
||
messages=[
|
||
{"role": "system", "content": SYSTEM_PROMPT},
|
||
{"role": "user", "content": user_message},
|
||
],
|
||
)
|
||
|
||
content = response.choices[0].message.content.strip()
|
||
|
||
json_match = re.search(r"```(?:json)?\s*(.*?)```", content, re.DOTALL)
|
||
if json_match:
|
||
content = json_match.group(1).strip()
|
||
|
||
try:
|
||
result = json.loads(content)
|
||
except json.JSONDecodeError:
|
||
json_match = re.search(r"\{.*\}", content, re.DOTALL)
|
||
if json_match:
|
||
result = json.loads(json_match.group())
|
||
else:
|
||
print("LLM 返回内容无法解析为 JSON:")
|
||
print(content)
|
||
sys.exit(1)
|
||
|
||
return result
|
||
|
||
|
||
def _normalize_scene_variants(img_info: dict, max_variants: int) -> list[tuple[str, dict]]:
|
||
"""从单条分镜解析待生成的画风变体,供绘图循环使用。
|
||
|
||
返回 [(文件名标签如 v01, variant 字典), ...]。
|
||
兼容旧版 JSON(无 variants 数组时退回顶层 prompt / prompt_en)。
|
||
"""
|
||
max_variants = max(1, min(2, int(max_variants)))
|
||
raw = img_info.get("variants")
|
||
collected: list[dict] = []
|
||
if isinstance(raw, list):
|
||
for v in raw:
|
||
if isinstance(v, dict) and (v.get("prompt") or v.get("prompt_en")):
|
||
collected.append(v)
|
||
if collected:
|
||
return [(f"v{idx:02d}", collected[idx - 1]) for idx in range(1, min(len(collected), max_variants) + 1)]
|
||
|
||
legacy = {
|
||
"style": img_info.get("style", ""),
|
||
"style_rationale": "",
|
||
"prompt": img_info.get("prompt", ""),
|
||
"prompt_en": img_info.get("prompt_en", ""),
|
||
}
|
||
if legacy["prompt"] or legacy["prompt_en"]:
|
||
return [("v01", legacy)]
|
||
return []
|
||
|
||
|
||
def display_analysis(analysis: dict) -> None:
|
||
"""友好地展示 LLM 的分析结果。"""
|
||
print(f"\n{'='*60}")
|
||
title = analysis.get("title", "未知")
|
||
author = analysis.get("author", "未知")
|
||
dynasty = analysis.get("dynasty", "")
|
||
genre = analysis.get("genre", "")
|
||
print(f"📜 {title} — {dynasty} · {author} [{genre}]")
|
||
print(f"{'='*60}")
|
||
print(f"\n🔍 意境分析:{analysis.get('analysis', '')}\n")
|
||
|
||
for i, img in enumerate(analysis["images"], 1):
|
||
print(f"{'─'*50}")
|
||
print(f"🖼 第 {i} 幅 | {img['scene']}")
|
||
desc = img.get("description", "")
|
||
if desc:
|
||
print(f" 中文描述:{desc}")
|
||
vlist = img.get("variants")
|
||
if isinstance(vlist, list) and vlist:
|
||
for vi, v in enumerate(vlist, 1):
|
||
print(f" ─ 画风方案 {vi}:{v.get('style', '未指定')}")
|
||
if v.get("style_rationale"):
|
||
print(f" 说明:{v['style_rationale']}")
|
||
zh = v.get("prompt") or ""
|
||
if zh:
|
||
tail = "..." if len(zh) > 120 else ""
|
||
print(f" Prompt(zh):{zh[:120]}{tail}")
|
||
en = v.get("prompt_en") or ""
|
||
if en:
|
||
tail = "..." if len(en) > 120 else ""
|
||
print(f" Prompt(en):{en[:120]}{tail}")
|
||
else:
|
||
print(f" 画风选择:{img.get('style', '未指定')}")
|
||
zh = img.get("prompt") or ""
|
||
if zh:
|
||
print(f" Prompt(zh):{zh[:120]}..." if len(zh) > 120 else f" Prompt(zh):{zh}")
|
||
if img.get("prompt_en"):
|
||
en = img["prompt_en"]
|
||
print(f" Prompt(en):{en[:120]}..." if len(en) > 120 else f" Prompt(en):{en}")
|
||
|
||
n_scenes = len(analysis["images"])
|
||
n_variants = sum(
|
||
len(img["variants"]) if isinstance(img.get("variants"), list) else (1 if img.get("prompt") or img.get("prompt_en") else 0)
|
||
for img in analysis["images"]
|
||
)
|
||
print(f"\n共 {n_scenes} 个分镜;LLM 共给出约 {n_variants} 套画风方案(生成张数受配置 style_variants 与 images_per_prompt 影响)\n")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 尺寸预设
|
||
# ---------------------------------------------------------------------------
|
||
|
||
SIZE_PRESETS: dict[str, tuple[int, int]] = {
|
||
"square": (1024, 1024),
|
||
"phone": ( 576, 1024),
|
||
"phone_hd": ( 768, 1344),
|
||
"desktop": (1024, 576),
|
||
"desktop_hd": (1344, 768),
|
||
"ultrawide": (1536, 640),
|
||
}
|
||
|
||
|
||
def resolve_image_size(img_cfg: dict) -> tuple[int, int]:
|
||
"""根据 size_preset 或 height/width 配置,返回 (width, height)。"""
|
||
preset = img_cfg.get("size_preset", "").strip().lower()
|
||
if preset and preset != "custom" and preset in SIZE_PRESETS:
|
||
w, h = SIZE_PRESETS[preset]
|
||
return w, h
|
||
return img_cfg.get("width", 1024), img_cfg.get("height", 1024)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Z-Image-Turbo 本地图片生成
|
||
# ---------------------------------------------------------------------------
|
||
|
||
HF_REPO = "Tongyi-MAI/Z-Image-Turbo"
|
||
|
||
# 从 HuggingFace 仓库下载的小型配置文件(首次需要网络,之后自动缓存)
|
||
_HF_CONFIG_FILES = [
|
||
"model_index.json",
|
||
"scheduler/scheduler_config.json",
|
||
"tokenizer/merges.txt",
|
||
"tokenizer/tokenizer_config.json",
|
||
"tokenizer/vocab.json",
|
||
"text_encoder/config.json",
|
||
"text_encoder/generation_config.json",
|
||
"transformer/config.json",
|
||
"vae/config.json",
|
||
]
|
||
|
||
|
||
def _force_link(src: Path, dst: Path) -> None:
|
||
"""创建从 dst 指向 src 的链接,兼容 Windows 无管理员权限的场景。
|
||
|
||
优先级: 符号链接 → 硬链接 → 复制文件
|
||
- 符号链接在 Windows 下需要管理员权限或开启开发者模式
|
||
- 硬链接无需特权但要求 src 和 dst 在同一驱动器
|
||
- 以上均失败时回退到复制(大文件会较慢,但保证可用)
|
||
"""
|
||
src = Path(src).resolve()
|
||
dst = Path(dst)
|
||
if dst.exists() or dst.is_symlink():
|
||
dst.unlink()
|
||
|
||
# 1. 尝试符号链接
|
||
try:
|
||
dst.symlink_to(src)
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 2. 尝试硬链接(要求同一驱动器/文件系统)
|
||
try:
|
||
os.link(str(src), str(dst))
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 3. 回退到复制
|
||
print(f" 提示: 无法创建链接,正在复制文件: {src.name}(可能较慢)")
|
||
shutil.copy2(str(src), str(dst))
|
||
|
||
|
||
def _is_comfyui_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 ComfyUI 拆分文件模式。"""
|
||
comfyui = cfg["image"].get("comfyui", {})
|
||
return bool(
|
||
comfyui.get("text_encoder")
|
||
and comfyui.get("transformer")
|
||
and comfyui.get("vae")
|
||
)
|
||
|
||
|
||
def _is_openvino_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 OpenVINO 推理模式。"""
|
||
return bool(cfg["image"].get("openvino", {}).get("model_path"))
|
||
|
||
|
||
def _load_pipeline_openvino(cfg: dict):
|
||
"""使用 OpenVINO 加载 Z-Image-Turbo pipeline。
|
||
|
||
需要预先通过 optimum-cli 导出 OpenVINO IR 模型:
|
||
optimum-cli export openvino --model Tongyi-MAI/Z-Image-Turbo \\
|
||
--weight-format int8 z-image-turbo-ov
|
||
"""
|
||
from optimum.intel import OVZImagePipeline
|
||
|
||
ov_cfg = cfg["image"]["openvino"]
|
||
model_path = ov_cfg["model_path"]
|
||
ov_device = ov_cfg.get("device", "GPU")
|
||
|
||
print(f"模式: OpenVINO 推理")
|
||
print(f" 模型路径 : {model_path}")
|
||
print(f" OV 设备 : {ov_device}")
|
||
|
||
if not Path(model_path).exists():
|
||
print(f"错误: OpenVINO 模型目录不存在: {model_path}")
|
||
print("请先使用 optimum-cli 导出模型:")
|
||
print(f" optimum-cli export openvino --model {HF_REPO} --weight-format int8 {model_path}")
|
||
sys.exit(1)
|
||
|
||
pipe = OVZImagePipeline.from_pretrained(model_path, device=ov_device)
|
||
return pipe
|
||
|
||
|
||
def _build_hf_layout_from_comfyui(cfg: dict) -> str:
|
||
"""从 ComfyUI 拆分文件构建 HuggingFace 兼容的目录布局。
|
||
|
||
原理:下载 HuggingFace 仓库中的微型配置文件(JSON/txt,共计 < 100KB),
|
||
然后创建指向 ComfyUI 权重文件的符号链接,最终得到一个
|
||
`ZImagePipeline.from_pretrained()` 可直接加载的目录。
|
||
"""
|
||
from huggingface_hub import hf_hub_download
|
||
|
||
comfyui = cfg["image"]["comfyui"]
|
||
te_path = Path(comfyui["text_encoder"]).resolve()
|
||
tf_path = Path(comfyui["transformer"]).resolve()
|
||
vae_path = Path(comfyui["vae"]).resolve()
|
||
|
||
for name, p in [("text_encoder", te_path), ("transformer", tf_path), ("vae", vae_path)]:
|
||
if not p.exists():
|
||
print(f"错误: ComfyUI {name} 文件不存在: {p}")
|
||
sys.exit(1)
|
||
|
||
cache_dir = Path(".cache") / "comfyui_hf_layout"
|
||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
print("正在准备 HuggingFace 兼容目录结构(仅首次需下载配置文件)...")
|
||
|
||
for rel_path in _HF_CONFIG_FILES:
|
||
dest = cache_dir / rel_path
|
||
if not dest.exists():
|
||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||
src = hf_hub_download(HF_REPO, rel_path)
|
||
shutil.copy2(src, dest)
|
||
|
||
# 链接权重文件 —— text_encoder
|
||
te_link = cache_dir / "text_encoder" / "model.safetensors"
|
||
_force_link(te_path, te_link)
|
||
# 删除分片索引(如果存在),因为 ComfyUI 的文件是单一非分片文件
|
||
shard_idx = cache_dir / "text_encoder" / "model.safetensors.index.json"
|
||
if shard_idx.exists():
|
||
shard_idx.unlink()
|
||
|
||
# 链接权重文件 —— transformer
|
||
tf_link = cache_dir / "transformer" / "diffusion_pytorch_model.safetensors"
|
||
_force_link(tf_path, tf_link)
|
||
|
||
# 链接权重文件 —— vae
|
||
vae_link = cache_dir / "vae" / "diffusion_pytorch_model.safetensors"
|
||
_force_link(vae_path, vae_link)
|
||
|
||
print(f"目录结构已就绪: {cache_dir}")
|
||
return str(cache_dir)
|
||
|
||
|
||
|
||
def _load_pipeline_comfyui(cfg: dict, device: str, torch_dtype: torch.dtype):
|
||
"""从 ComfyUI 拆分文件加载 pipeline(使用逐组件加载方式,仅支持 safetensors)。"""
|
||
from diffusers import (
|
||
AutoencoderKL,
|
||
FlowMatchEulerDiscreteScheduler,
|
||
ZImagePipeline,
|
||
ZImageTransformer2DModel,
|
||
)
|
||
from transformers import AutoTokenizer, Qwen3Model
|
||
|
||
comfyui = cfg["image"]["comfyui"]
|
||
te_path = comfyui["text_encoder"]
|
||
tf_path = comfyui["transformer"]
|
||
vae_path = comfyui["vae"]
|
||
|
||
print("模式: ComfyUI 拆分文件加载")
|
||
print(f" Text Encoder : {te_path}")
|
||
print(f" Transformer : {tf_path}")
|
||
print(f" VAE : {vae_path}")
|
||
|
||
print(" 加载 Scheduler & Tokenizer(配置来自 HuggingFace 缓存)...")
|
||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(HF_REPO, subfolder="scheduler")
|
||
tokenizer = AutoTokenizer.from_pretrained(HF_REPO, subfolder="tokenizer")
|
||
|
||
print(" 加载 Transformer...")
|
||
transformer = ZImageTransformer2DModel.from_single_file(
|
||
tf_path,
|
||
config=HF_REPO,
|
||
subfolder="transformer",
|
||
torch_dtype=torch_dtype,
|
||
)
|
||
|
||
print(" 加载 VAE...")
|
||
vae = AutoencoderKL.from_single_file(
|
||
vae_path,
|
||
config=HF_REPO,
|
||
subfolder="vae",
|
||
torch_dtype=torch_dtype,
|
||
)
|
||
|
||
print(" 加载 Text Encoder (Qwen3 4B)...")
|
||
te_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/config.json")
|
||
te_gen_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/generation_config.json")
|
||
|
||
te_parent_dir = str(Path(te_path).resolve().parent)
|
||
with tempfile.TemporaryDirectory(dir=te_parent_dir) as tmpdir:
|
||
shutil.copy2(te_config_path, os.path.join(tmpdir, "config.json"))
|
||
shutil.copy2(te_gen_config_path, os.path.join(tmpdir, "generation_config.json"))
|
||
_force_link(Path(te_path), Path(tmpdir) / "model.safetensors")
|
||
text_encoder = Qwen3Model.from_pretrained(tmpdir, torch_dtype=torch_dtype)
|
||
|
||
pipe = ZImagePipeline(
|
||
scheduler=scheduler,
|
||
vae=vae,
|
||
text_encoder=text_encoder,
|
||
tokenizer=tokenizer,
|
||
transformer=transformer,
|
||
)
|
||
return pipe
|
||
|
||
|
||
def hf_hub_download_cached(repo_id: str, filename: str) -> str:
|
||
"""下载 HuggingFace 仓库中的文件(自动缓存)。"""
|
||
from huggingface_hub import hf_hub_download
|
||
return hf_hub_download(repo_id, filename)
|
||
|
||
|
||
def load_pipeline(cfg: dict):
|
||
"""加载 Z-Image-Turbo pipeline。自动适配 OpenVINO / HuggingFace / ComfyUI 格式。"""
|
||
img_cfg = cfg["image"]
|
||
|
||
# OpenVINO 模式:由 optimum.intel 管理设备,无需手动 resolve_device
|
||
if _is_openvino_mode(cfg):
|
||
ov_device = img_cfg["openvino"].get("device", "GPU")
|
||
print(f"\n{'='*60}")
|
||
print("正在加载 Z-Image-Turbo 模型 (OpenVINO)...")
|
||
print(f"OpenVINO 设备: {ov_device}")
|
||
print(f"{'='*60}\n")
|
||
|
||
pipe = _load_pipeline_openvino(cfg)
|
||
cfg["_resolved_device"] = "cpu"
|
||
cfg["_openvino_mode"] = True
|
||
return pipe
|
||
|
||
from diffusers import ZImagePipeline
|
||
|
||
device = resolve_device(img_cfg.get("device", "auto"))
|
||
torch_dtype = get_supported_dtype(device, img_cfg.get("torch_dtype", "auto"))
|
||
|
||
print(f"\n{'='*60}")
|
||
print("正在加载 Z-Image-Turbo 模型...")
|
||
print(f"推理设备: {device}")
|
||
print(f"数据类型: {torch_dtype}")
|
||
print(f"{'='*60}\n")
|
||
|
||
if _is_comfyui_mode(cfg):
|
||
pipe = _load_pipeline_comfyui(cfg, device, torch_dtype)
|
||
else:
|
||
model_id = img_cfg["model_id"]
|
||
print(f"模式: HuggingFace 标准加载")
|
||
print(f"模型路径: {model_id}")
|
||
pipe = ZImagePipeline.from_pretrained(
|
||
model_id,
|
||
torch_dtype=torch_dtype,
|
||
low_cpu_mem_usage=False,
|
||
)
|
||
|
||
raw_offload = str(img_cfg.get("enable_cpu_offload", "false")).strip().lower()
|
||
offload_mode = {
|
||
"false": None, "0": None, "no": None, "off": None,
|
||
"true": "model", "1": "model", "yes": "model", "on": "model",
|
||
"model": "model",
|
||
"sequential": "sequential",
|
||
}.get(raw_offload, None)
|
||
|
||
if offload_mode and device in ("cuda", "xpu", "mps"):
|
||
if offload_mode == "sequential":
|
||
print("启用 Sequential CPU Offload: 逐层搬入显卡,最省显存但较慢")
|
||
pipe.enable_sequential_cpu_offload(device=device)
|
||
else:
|
||
print("启用 Model CPU Offload: 组件级按需加载到显卡")
|
||
pipe.enable_model_cpu_offload(device=device)
|
||
else:
|
||
if device != "cpu" and not offload_mode:
|
||
print("提示: 所有模型将同时加载到显卡,如显存不足请在配置中开启 enable_cpu_offload")
|
||
pipe.to(device)
|
||
|
||
if device not in ("xpu", "cpu"):
|
||
attn_backend = img_cfg.get("attention_backend", "sdpa")
|
||
if attn_backend == "flash":
|
||
pipe.transformer.set_attention_backend("flash")
|
||
elif attn_backend == "flash_3":
|
||
pipe.transformer.set_attention_backend("_flash_3")
|
||
|
||
lora_cfg = cfg.get("lora", {})
|
||
if lora_cfg.get("enabled") and lora_cfg.get("path"):
|
||
lora_path = lora_cfg["path"]
|
||
lora_weight = lora_cfg.get("weight", 0.8)
|
||
print(f"正在加载 LoRA: {lora_path} (权重: {lora_weight})")
|
||
pipe.load_lora_weights(lora_path)
|
||
pipe.fuse_lora(lora_scale=lora_weight)
|
||
print("LoRA 加载完成")
|
||
|
||
cfg["_resolved_device"] = device
|
||
cfg["_openvino_mode"] = False
|
||
return pipe
|
||
|
||
|
||
def generate_images(pipe, analysis: dict, cfg: dict) -> list[Path]:
|
||
"""根据分析结果逐一生成图片,返回保存路径列表。"""
|
||
img_cfg = cfg["image"]
|
||
out_cfg = cfg["output"]
|
||
lora_cfg = cfg.get("lora", {})
|
||
device = cfg.get("_resolved_device", "cpu")
|
||
|
||
output_dir = Path(out_cfg.get("dir", "./output"))
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
prefix = out_cfg.get("filename_prefix", "poem")
|
||
width, height = resolve_image_size(img_cfg)
|
||
steps = img_cfg.get("num_inference_steps", 9)
|
||
guidance = img_cfg.get("guidance_scale", 0.0)
|
||
seed = img_cfg.get("seed", -1)
|
||
|
||
trigger_words = ""
|
||
if lora_cfg.get("enabled") and lora_cfg.get("trigger_words"):
|
||
trigger_words = lora_cfg["trigger_words"].strip()
|
||
|
||
preset = img_cfg.get("size_preset", "custom")
|
||
prompt_lang = img_cfg.get("prompt_language", "zh")
|
||
images_per_prompt = max(1, min(10, img_cfg.get("images_per_prompt", 1)))
|
||
max_style_variants = max(1, min(2, int(img_cfg.get("style_variants", 2))))
|
||
print(f"图片尺寸: {width}×{height}" + (f" (预设: {preset})" if preset != "custom" else ""))
|
||
print(f"Prompt 语言: {prompt_lang}")
|
||
print(f"每分镜画风方案数: {max_style_variants}(配置项 style_variants,1 或 2)")
|
||
if images_per_prompt > 1:
|
||
print(f"每个 prompt 生成 {images_per_prompt} 张图(不同种子)")
|
||
|
||
saved_paths = []
|
||
total = len(analysis["images"])
|
||
|
||
for i, img_info in enumerate(analysis["images"], 1):
|
||
variant_list = _normalize_scene_variants(img_info, max_style_variants)
|
||
if not variant_list:
|
||
print(f"\n警告: 第 {i}/{total} 幅分镜无有效 prompt,已跳过: {img_info.get('scene', '')}")
|
||
continue
|
||
|
||
print(f"\n[{i}/{total}] 分镜: {img_info['scene']}")
|
||
|
||
if out_cfg.get("save_prompts", True):
|
||
txt_lines = [
|
||
f"Scene: {img_info['scene']}\n",
|
||
f"Description: {img_info.get('description', '')}\n",
|
||
f"Prompt_language_used: {prompt_lang}\n",
|
||
]
|
||
|
||
for vi, (v_label, variant) in enumerate(variant_list):
|
||
if prompt_lang == "en" and variant.get("prompt_en"):
|
||
prompt = variant["prompt_en"]
|
||
else:
|
||
prompt = variant.get("prompt") or variant.get("prompt_en") or ""
|
||
if trigger_words:
|
||
prompt = f"{trigger_words}, {prompt}"
|
||
|
||
st = variant.get("style", "未指定")
|
||
print(f" [{v_label}] 画风: {st}")
|
||
prev = prompt[:120] + ("..." if len(prompt) > 120 else "")
|
||
print(f" Prompt({prompt_lang}): {prev}")
|
||
|
||
if out_cfg.get("save_prompts", True):
|
||
txt_lines.append(f"\n--- {v_label} | Style: {st} ---\n")
|
||
if variant.get("style_rationale"):
|
||
txt_lines.append(f"Rationale: {variant['style_rationale']}\n")
|
||
txt_lines.append(f"Prompt(zh): {variant.get('prompt', '')}\n")
|
||
txt_lines.append(f"Prompt(en): {variant.get('prompt_en', '')}\n")
|
||
txt_lines.append(f"Used({prompt_lang}): {prompt}\n")
|
||
|
||
for j in range(images_per_prompt):
|
||
variant_offset = i * 100 + vi * 17 + j
|
||
actual_seed = (seed + variant_offset) if seed >= 0 else (int(time.time() * 1000) % (2**32) + variant_offset)
|
||
generator = create_generator(device, actual_seed)
|
||
|
||
seed_suffix = chr(ord("a") + j) if images_per_prompt > 1 else ""
|
||
if images_per_prompt > 1:
|
||
print(f" --- 同画风第 {j+1}/{images_per_prompt} 张 (seed={actual_seed}) ---")
|
||
|
||
start_time = time.time()
|
||
|
||
result = pipe(
|
||
prompt=prompt,
|
||
height=height,
|
||
width=width,
|
||
num_inference_steps=steps,
|
||
guidance_scale=guidance,
|
||
generator=generator,
|
||
)
|
||
image: Image.Image = result.images[0]
|
||
|
||
elapsed = time.time() - start_time
|
||
print(f" 生成完成,耗时 {elapsed:.1f}s")
|
||
|
||
img_path = output_dir / f"{prefix}_{i:02d}_{v_label}{seed_suffix}.png"
|
||
image.save(img_path)
|
||
saved_paths.append(img_path)
|
||
print(f" 已保存: {img_path}")
|
||
|
||
if out_cfg.get("save_prompts", True):
|
||
txt_path = output_dir / f"{prefix}_{i:02d}_prompt.txt"
|
||
txt_path.write_text("".join(txt_lines), encoding="utf-8")
|
||
|
||
return saved_paths
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 主流程
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(
|
||
description="古诗词意境图生成器 — 基于 LLM 分析 + Z-Image-Turbo 生成"
|
||
)
|
||
parser.add_argument(
|
||
"-c", "--config",
|
||
default="config.yaml",
|
||
help="配置文件路径(默认: config.yaml)",
|
||
)
|
||
parser.add_argument(
|
||
"-p", "--poem",
|
||
type=str,
|
||
default=None,
|
||
help="直接传入古诗词文本(如不指定则交互式输入)",
|
||
)
|
||
parser.add_argument(
|
||
"--analyze-only",
|
||
action="store_true",
|
||
help="仅进行 LLM 分析,不生成图片",
|
||
)
|
||
parser.add_argument(
|
||
"-o", "--output",
|
||
type=str,
|
||
default=None,
|
||
help="覆盖输出目录",
|
||
)
|
||
args = parser.parse_args()
|
||
|
||
cfg = load_config(args.config)
|
||
|
||
if args.output:
|
||
cfg["output"]["dir"] = args.output
|
||
else:
|
||
now = datetime.now()
|
||
date_dir = now.strftime("%Y-%m-%d")
|
||
time_dir = now.strftime("%H-%M-%S")
|
||
cfg["output"]["dir"] = str(Path(cfg["output"].get("dir", "./output")) / date_dir / time_dir)
|
||
|
||
if args.poem:
|
||
poem = args.poem
|
||
else:
|
||
print("请输入古诗词(输入空行结束):")
|
||
lines = []
|
||
while True:
|
||
line = input()
|
||
if line.strip() == "":
|
||
break
|
||
lines.append(line)
|
||
poem = "\n".join(lines)
|
||
|
||
if not poem.strip():
|
||
print("未输入任何内容,退出。")
|
||
sys.exit(0)
|
||
|
||
print(f"\n📝 输入的诗词:\n{poem}")
|
||
|
||
analysis = analyze_poetry(poem, cfg)
|
||
display_analysis(analysis)
|
||
|
||
output_dir = Path(cfg["output"].get("dir", "./output"))
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
analysis_path = output_dir / "analysis.json"
|
||
analysis_path.write_text(
|
||
json.dumps(analysis, ensure_ascii=False, indent=2),
|
||
encoding="utf-8",
|
||
)
|
||
print(f"分析结果已保存: {analysis_path}")
|
||
|
||
if args.analyze_only:
|
||
print("\n已完成分析(--analyze-only 模式),跳过图片生成。")
|
||
return
|
||
|
||
if _is_openvino_mode(cfg):
|
||
ov_device = cfg["image"]["openvino"].get("device", "GPU")
|
||
print(f"\n🖥 推理模式: OpenVINO ({ov_device})")
|
||
if ov_device.upper() == "GPU" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||
print(f" Intel XPU: {torch.xpu.get_device_name(0)}")
|
||
else:
|
||
device = resolve_device(cfg["image"].get("device", "auto"))
|
||
print(f"\n🖥 推理设备: {device}")
|
||
if device == "xpu":
|
||
print(f" Intel XPU: {torch.xpu.get_device_name(0)}")
|
||
print(f" 显存: {torch.xpu.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
||
elif device == "cuda":
|
||
print(f" CUDA GPU: {torch.cuda.get_device_name(0)}")
|
||
print(f" 显存: {torch.cuda.get_device_properties(0).total_mem / 1024**3:.1f} GB")
|
||
|
||
pipe = load_pipeline(cfg)
|
||
saved = generate_images(pipe, analysis, cfg)
|
||
|
||
print(f"\n{'='*60}")
|
||
print(f"全部完成!共生成 {len(saved)} 幅图片:")
|
||
for p in saved:
|
||
print(f" 📁 {p}")
|
||
print(f"{'='*60}\n")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|