855 lines
32 KiB
Python
855 lines
32 KiB
Python
"""
|
||
古诗词意境图生成器
|
||
将中国古典诗词通过 LLM 分析拆解为多个意境画面,
|
||
再使用 Z-Image-Turbo 本地模型逐一生成高质量图片。
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import yaml
|
||
from openai import OpenAI
|
||
from PIL import Image
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 设备检测与适配
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _init_xpu():
|
||
"""尝试初始化 Intel XPU 支持(需要 intel-extension-for-pytorch)。"""
|
||
try:
|
||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||
return True
|
||
except ImportError:
|
||
return False
|
||
|
||
|
||
def resolve_device(configured_device: str) -> str:
    """Decide the actual inference device from config and hardware availability.

    Priority: an explicit user configuration wins; "auto" probes hardware.
    Auto-detect order: cuda > xpu > mps > cpu.

    Exits the process with status 1 when "xpu" is requested but no Intel XPU
    can be initialized.
    """
    if configured_device == "auto":
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # Best-effort IPEX import for its optimizations; result ignored.
            _init_xpu()
            return "xpu"
        if hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    if configured_device == "xpu":
        if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
            print("警告: 配置了 xpu 设备但未检测到 Intel XPU,尝试初始化 IPEX...")
            if not _init_xpu():
                print("错误: 无法加载 intel-extension-for-pytorch,请确认已安装。")
                print("安装命令: pip install intel-extension-for-pytorch")
                sys.exit(1)
            # IPEX imported but the device may still be absent (e.g. no driver).
            if not torch.xpu.is_available():
                print("错误: IPEX 已加载但仍未检测到 XPU 设备。")
                sys.exit(1)
        else:
            # XPU already visible; still import IPEX for its optimizations.
            _init_xpu()

    # Any other explicit device (cuda/mps/cpu/...) is trusted as-is.
    return configured_device
|
||
|
||
|
||
def get_supported_dtype(device: str, configured_dtype: str) -> torch.dtype:
    """Pick a torch dtype suited to the target device.

    Intel Arc GPUs have spotty bfloat16 operator coverage, so float16 is
    preferred (and silently substituted for an explicit bfloat16) on xpu.
    """
    if configured_dtype == "auto":
        auto_by_device = {
            "xpu": torch.float16,
            "cuda": torch.bfloat16,
            "mps": torch.bfloat16,
        }
        return auto_by_device.get(device, torch.float32)

    by_name = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    # Unknown names fall back to bfloat16 (same as the historical behavior).
    chosen = by_name.get(configured_dtype, torch.bfloat16)

    if device == "xpu" and chosen == torch.bfloat16:
        print("提示: Intel Arc GPU 上 bfloat16 部分算子兼容性不佳,自动切换为 float16")
        return torch.float16

    return chosen
|
||
|
||
|
||
def create_generator(device: str, seed: int) -> torch.Generator:
    """Return a torch RNG seeded with `seed`, placed on xpu/cuda when asked.

    Any other device name (cpu, mps, ...) gets the default CPU generator.
    """
    if device in ("xpu", "cuda"):
        rng = torch.Generator(device)
    else:
        rng = torch.Generator()
    return rng.manual_seed(seed)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 配置加载
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def load_config(config_path: str = "config.yaml") -> dict:
    """Read the YAML config file and return it as a dict.

    The LLM_API_KEY environment variable, when set, overrides the
    llm.api_key value from the file.
    """
    with open(config_path, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    config["llm"]["api_key"] = (
        os.environ.get("LLM_API_KEY") or config["llm"].get("api_key", "")
    )
    return config
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM 古诗词分析
|
||
# ---------------------------------------------------------------------------
|
||
|
||
SYSTEM_PROMPT = """\
|
||
# Role(角色设定)
|
||
|
||
你是一位顶级的中国古典文学泰斗,同时也是一位精通 AI 文本到图像生成(Text-to-Image)\
|
||
底层逻辑的顶级提示词工程师(Prompt Engineer)。\
|
||
你对中国古诗词中的"意境"、"留白"、"虚实相生"有极其深刻的理解,\
|
||
并且知道如何将这些抽象的美学概念转化为扩散模型(Diffusion Models)能够精准识别的\
|
||
视觉特征参数(如光影、材质、构图、渲染引擎词汇)。
|
||
|
||
# Objective(工作目标)
|
||
|
||
你的任务是接收用户输入的古诗词,严格按照"四段式思维链"将其转化为最高质量的图像生成提示词。\
|
||
你需要具备探索长诗或多句诗词连贯多图意象的能力,\
|
||
确保最终生成的单张或多张分镜图像能够完美传达原诗的意境,而不只是生硬的元素堆砌。
|
||
|
||
# Workflow(强制执行四段式思维链)
|
||
|
||
对于用户的每一次输入,你必须严格按顺序在内部执行以下四个步骤,缺一不可:
|
||
|
||
## 第一步:意境与分镜逻辑判断
|
||
|
||
重要分析全诗的时空连贯性:
|
||
- 如果全诗描绘的是同一时间、同一地点的统一场景,生成【单幅画面】
|
||
- 如果诗句间存在明显的视角切换(如远景切特写)、时间推移(如白天到黑夜)或场景跳跃,\
|
||
按内在逻辑拆分为 2 到 4 幅画面的【分镜序列】
|
||
- 意境连贯的相邻诗句应合并为一幅,避免碎片化
|
||
|
||
## 第二步:意境深度解析
|
||
|
||
针对每一个分镜(或单幅画面),分析:
|
||
- 核心情感基调(苍凉悲壮 / 空灵婉约 / 萧瑟肃杀 / 雄浑壮阔 / 闲适恬淡 / 凄婉哀怨等)
|
||
- 季节时间与天气状态
|
||
- "意境"类型与情感张力
|
||
|
||
## 第三步:现代文视觉转义
|
||
|
||
将每一个分镜扩写为极具画面感的现代文视觉脚本。\
|
||
你必须大胆发挥想象力,补全诗句中省略的视觉细节,明确写出:
|
||
- **主体景物**:人物姿态、动作、表情、服饰;核心景物的具体形态
|
||
- **配景与地理环境**:山川、水域、植被、建筑等空间层次
|
||
- **光线条件**:斜阳逆光、清冷月光、破晓微光、黄昏余晖等
|
||
- **天气效果**:晨雾弥漫、细雨如织、大雪纷飞、长风浩荡等
|
||
- **画面构图**:大远景 / 中景 / 特写 / 俯瞰 / 平视等
|
||
|
||
## 第四步:图像生成 Prompt 生成
|
||
|
||
基于第三步的现代文视觉脚本,为每一个分镜生成精确的图像 Prompt。
|
||
|
||
### Prompt 结构(必须遵循)
|
||
|
||
每个 Prompt 必须涵盖以下六大要素,按顺序自然融合为一段连贯流畅的描述文字:
|
||
1. 画面主体:核心人物 / 景物及其状态
|
||
2. 环境背景:空间层次、地理环境、建筑植被
|
||
3. 场景光影:具体光源、光线方向、明暗对比
|
||
4. 气候与氛围:天气、季节、情感色彩
|
||
5. 艺术风格与媒介:中国传统画风关键词 + 媒介质感
|
||
6. 图像质量词:masterpiece, 8k resolution, highly detailed 等
|
||
|
||
【极其重要】最终输出的 prompt 和 prompt_en 必须是自然流畅的连续段落,\
|
||
绝对不要使用方括号 [] 标注要素名称,不要出现类似"[画面主体:...]"的格式标签。\
|
||
六大要素是你内部的组织逻辑,输出时必须将它们无缝融合为一段完整的、富有画面感的描述。
|
||
|
||
### Prompt 长度要求
|
||
|
||
Z-Image-Turbo 非常适合处理包含丰富细节的长描述提示词:
|
||
- 中文 Prompt:80-250 字
|
||
- 英文 Prompt:80-200 词
|
||
|
||
### 风格约束(极其重要)
|
||
|
||
Z-Image-Turbo 不支持负面提示词(Negative Prompts),所有约束必须以正向描述表达。\
|
||
为确保生成"古诗词意境"而非现代写实照片,你必须在 Prompt 末尾加上强有力的风格约束词。\
|
||
以下是可根据诗意灵活选用的风格约束:
|
||
|
||
| 风格 | Prompt 约束词 |
|
||
|------|-------------|
|
||
| 水墨写意 | Traditional Chinese ink wash painting (中国传统水墨画), freehand brushwork (写意), \
|
||
negative space (留白), ethereal atmosphere (空灵的氛围) |
|
||
| 青绿山水 | Traditional Chinese blue-green landscape painting (青绿山水), mineral pigments (石青石绿), \
|
||
golden and jade-like tones (金碧辉煌) |
|
||
| 工笔花鸟 | Chinese meticulous brushwork (工笔), fine detailed rendering (精细渲染), \
|
||
delicate line drawing (细腻勾勒) |
|
||
| 工笔重彩 | Chinese meticulous heavy-color painting (工笔重彩), rich saturated pigments (浓墨重色), \
|
||
elaborate detail (华丽精细) |
|
||
| 文人画 | Chinese literati painting (文人画), poetry-calligraphy-painting unity (诗书画印一体), \
|
||
lofty elegance (意趣高远) |
|
||
| 泼墨大写意 | Splash ink painting (泼墨大写意), bold expressive brushstrokes (墨色淋漓), \
|
||
majestic momentum (气势磅礴) |
|
||
| 浅绛山水 | Light crimson landscape painting (浅绛山水), ochre wash (赭石淡彩), \
|
||
sparse and distant (萧疏清远) |
|
||
|
||
通用质量约束词(所有风格都应附加):\
|
||
masterpiece, 8k resolution, highly detailed, cinematic composition
|
||
|
||
如果用户指定了风格期望,请优先使用用户指定的风格。\
|
||
如果用户未指定风格,请根据诗意自动选择最契合的传统画风。
|
||
|
||
### 中文 Prompt 要求
|
||
- 使用中国传统绘画的专业术语
|
||
- 具体且富有画面感,避免抽象空泛的概念
|
||
- 末尾必须附加风格约束词和质量约束词
|
||
|
||
### 英文 Prompt 要求
|
||
- 中文 Prompt 的忠实翻译与适配,保持相同的画面内容和风格意图
|
||
- 使用对应的英文艺术术语
|
||
- 自然流畅的英文表达,非逐字翻译
|
||
- 末尾必须附加英文风格约束词和质量约束词
|
||
|
||
# Rules(输出规则)
|
||
|
||
严格按照以下 JSON 格式输出结果,不要输出任何与格式无关的文字。\
|
||
四段式思维链的推理过程请融入到对应的 JSON 字段中:
|
||
|
||
```json
|
||
{
|
||
"title": "诗词标题",
|
||
"author": "作者",
|
||
"dynasty": "朝代",
|
||
"genre": "体裁(如:五言绝句、七言律诗、词·水调歌头等)",
|
||
"analysis": "第一步【分镜逻辑判断】的理由 + 第二步【意境深度解析】的综合分析:包含分镜拆分依据、整首诗的意境类型、核心情感基调、时空特征(中文,3-5句话)",
|
||
"images": [
|
||
{
|
||
"scene": "这幅画对应的诗句(原文)",
|
||
"description": "第三步【现代文视觉转义】的完整输出:极具画面感的视觉脚本,包含主体景物、配景、光线、天气、构图等所有视觉细节(中文,100-200字)",
|
||
"style": "选用的画风(中文名称,如:水墨写意、青绿山水、工笔花鸟等)",
|
||
"prompt": "第四步生成的中文 Prompt,自然融合六大要素为连续流畅的段落(禁止使用方括号标注),末尾附加风格约束词和质量词,80-250字",
|
||
"prompt_en": "Step 4 English Prompt, naturally blending all six elements into a fluent paragraph (NO square brackets), ending with style and quality keywords, 80-200 words"
|
||
}
|
||
]
|
||
}
|
||
```\
|
||
"""
|
||
|
||
|
||
def _build_user_message(poem: str, cfg: dict) -> str:
|
||
"""构造发送给 LLM 的用户消息,包含诗词和可选的风格期望。"""
|
||
style_pref = cfg["image"].get("style_preference", "").strip()
|
||
if style_pref:
|
||
style_line = f"【风格期望】:{style_pref}"
|
||
else:
|
||
style_line = "【风格期望】:默认(根据诗意自动选择最契合的传统画风)"
|
||
|
||
return (
|
||
f"请为以下古诗词生成图像提示词:\n\n"
|
||
f"【输入诗词】:\n{poem}\n\n"
|
||
f"{style_line}\n\n"
|
||
f"请严格按照 System Prompt 的要求,首先进行【意境与分镜逻辑判断】,"
|
||
f"随后针对单幅或多幅分镜依次输出对应的【意境深度解析】、"
|
||
f"【现代文视觉转义】以及最终的【图像生成 Prompt】。"
|
||
)
|
||
|
||
|
||
def analyze_poetry(poem: str, cfg: dict) -> dict:
    """Ask the LLM to analyze the poem; return the structured image plan.

    Returns the parsed JSON dict described by SYSTEM_PROMPT. Exits the
    process with status 1 when the LLM reply cannot be parsed as JSON.
    """
    llm_cfg = cfg["llm"]

    client = OpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        timeout=60,
    )

    style_pref = cfg["image"].get("style_preference", "").strip()
    print(f"\n{'='*60}")
    print("正在调用 LLM 分析古诗词意境(四段式思维链)...")
    print(f"模型: {llm_cfg['model']}")
    if style_pref:
        print(f"风格期望: {style_pref}")
    print(f"{'='*60}\n")

    user_message = _build_user_message(poem, cfg)

    response = client.chat.completions.create(
        model=llm_cfg["model"],
        temperature=llm_cfg.get("temperature", 0.7),
        max_tokens=llm_cfg.get("max_tokens", 4096),
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message},
        ],
    )

    content = response.choices[0].message.content.strip()

    # Strip a ```json ... ``` fence if the model wrapped its reply in one.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", content, re.DOTALL)
    if fenced:
        content = fenced.group(1).strip()

    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    # Fallback: grab the outermost {...} span and retry.
    # BUG FIX: previously a malformed brace span raised an uncaught
    # JSONDecodeError here; now it falls through to the friendly error path.
    braced = re.search(r"\{.*\}", content, re.DOTALL)
    if braced:
        try:
            return json.loads(braced.group())
        except json.JSONDecodeError:
            pass

    print("LLM 返回内容无法解析为 JSON:")
    print(content)
    sys.exit(1)
|
||
|
||
|
||
def display_analysis(analysis: dict) -> None:
    """Pretty-print the LLM analysis result to stdout."""
    bar = "=" * 60
    print(f"\n{bar}")
    header = "📜 {} — {} · {} [{}]".format(
        analysis.get("title", "未知"),
        analysis.get("dynasty", ""),
        analysis.get("author", "未知"),
        analysis.get("genre", ""),
    )
    print(header)
    print(bar)
    print(f"\n🔍 意境分析:{analysis.get('analysis', '')}\n")

    panels = analysis["images"]
    for idx, panel in enumerate(panels, start=1):
        print("─" * 50)
        print(f"🖼 第 {idx} 幅 | {panel['scene']}")
        print(f" 画风选择:{panel.get('style', '未指定')}")
        print(f" 中文描述:{panel['description']}")
        print(f" Prompt(zh):{panel['prompt'][:120]}...")
        en_prompt = panel.get("prompt_en")
        if en_prompt:
            print(f" Prompt(en):{en_prompt[:120]}...")

    print(f"\n共 {len(panels)} 幅画面\n")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 尺寸预设
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Named (width, height) presets selectable via image.size_preset.
SIZE_PRESETS: dict[str, tuple[int, int]] = {
    "square": (1024, 1024),
    "phone": (576, 1024),
    "phone_hd": (768, 1344),
    "desktop": (1024, 576),
    "desktop_hd": (1344, 768),
    "ultrawide": (1536, 640),
}


def resolve_image_size(img_cfg: dict) -> tuple[int, int]:
    """Return (width, height); a recognized size_preset wins over width/height."""
    preset_name = img_cfg.get("size_preset", "").strip().lower()
    if preset_name != "custom" and preset_name in SIZE_PRESETS:
        return SIZE_PRESETS[preset_name]
    # Fall back to explicit dimensions, defaulting to a 1024 square.
    return img_cfg.get("width", 1024), img_cfg.get("height", 1024)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Z-Image-Turbo 本地图片生成
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# HuggingFace repository hosting the official Z-Image-Turbo pipeline.
HF_REPO = "Tongyi-MAI/Z-Image-Turbo"

# Small config files fetched from the HuggingFace repo (network needed on
# first run only; cached afterwards). Used to rebuild an HF-compatible
# directory layout around ComfyUI split weight files.
_HF_CONFIG_FILES = [
    "model_index.json",
    "scheduler/scheduler_config.json",
    "tokenizer/merges.txt",
    "tokenizer/tokenizer_config.json",
    "tokenizer/vocab.json",
    "text_encoder/config.json",
    "text_encoder/generation_config.json",
    "transformer/config.json",
    "vae/config.json",
]
|
||
|
||
|
||
def _force_link(src: Path, dst: Path) -> None:
|
||
"""创建从 dst 指向 src 的链接,兼容 Windows 无管理员权限的场景。
|
||
|
||
优先级: 符号链接 → 硬链接 → 复制文件
|
||
- 符号链接在 Windows 下需要管理员权限或开启开发者模式
|
||
- 硬链接无需特权但要求 src 和 dst 在同一驱动器
|
||
- 以上均失败时回退到复制(大文件会较慢,但保证可用)
|
||
"""
|
||
src = Path(src).resolve()
|
||
dst = Path(dst)
|
||
if dst.exists() or dst.is_symlink():
|
||
dst.unlink()
|
||
|
||
# 1. 尝试符号链接
|
||
try:
|
||
dst.symlink_to(src)
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 2. 尝试硬链接(要求同一驱动器/文件系统)
|
||
try:
|
||
os.link(str(src), str(dst))
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 3. 回退到复制
|
||
print(f" 提示: 无法创建链接,正在复制文件: {src.name}(可能较慢)")
|
||
shutil.copy2(str(src), str(dst))
|
||
|
||
|
||
def _is_comfyui_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 ComfyUI 拆分文件模式。"""
|
||
comfyui = cfg["image"].get("comfyui", {})
|
||
return bool(
|
||
comfyui.get("text_encoder")
|
||
and comfyui.get("transformer")
|
||
and comfyui.get("vae")
|
||
)
|
||
|
||
|
||
def _is_openvino_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 OpenVINO 推理模式。"""
|
||
return bool(cfg["image"].get("openvino", {}).get("model_path"))
|
||
|
||
|
||
def _load_pipeline_openvino(cfg: dict):
    """Load the Z-Image-Turbo pipeline through OpenVINO (optimum-intel).

    The OpenVINO IR model must have been exported beforehand:
        optimum-cli export openvino --model Tongyi-MAI/Z-Image-Turbo \\
            --weight-format int8 z-image-turbo-ov

    Exits with status 1 when the exported model directory is missing.
    """
    from optimum.intel import OVZImagePipeline

    ov_cfg = cfg["image"]["openvino"]
    model_path = ov_cfg["model_path"]
    ov_device = ov_cfg.get("device", "GPU")

    print(f"模式: OpenVINO 推理")
    print(f" 模型路径 : {model_path}")
    print(f" OV 设备 : {ov_device}")

    # The exported IR directory must already exist; bail out with guidance.
    if not Path(model_path).exists():
        print(f"错误: OpenVINO 模型目录不存在: {model_path}")
        print("请先使用 optimum-cli 导出模型:")
        print(f" optimum-cli export openvino --model {HF_REPO} --weight-format int8 {model_path}")
        sys.exit(1)

    return OVZImagePipeline.from_pretrained(model_path, device=ov_device)
|
||
|
||
|
||
def _build_hf_layout_from_comfyui(cfg: dict) -> str:
    """Build a HuggingFace-compatible directory layout from ComfyUI split files.

    Approach: download the tiny config files from the HF repo (JSON/txt,
    < 100KB total), then link the large ComfyUI weight files into place,
    yielding a directory that `ZImagePipeline.from_pretrained()` can load
    directly.

    Returns the layout directory path as a string. Exits with status 1 when
    any configured ComfyUI weight file is missing.
    """
    from huggingface_hub import hf_hub_download

    comfyui = cfg["image"]["comfyui"]
    te_path = Path(comfyui["text_encoder"]).resolve()
    tf_path = Path(comfyui["transformer"]).resolve()
    vae_path = Path(comfyui["vae"]).resolve()

    # Fail fast when any of the three weight files is absent.
    for name, p in [("text_encoder", te_path), ("transformer", tf_path), ("vae", vae_path)]:
        if not p.exists():
            print(f"错误: ComfyUI {name} 文件不存在: {p}")
            sys.exit(1)

    cache_dir = Path(".cache") / "comfyui_hf_layout"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("正在准备 HuggingFace 兼容目录结构(仅首次需下载配置文件)...")

    # Fetch the small config files once; later runs reuse the local copies.
    for rel_path in _HF_CONFIG_FILES:
        dest = cache_dir / rel_path
        if not dest.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)
            src = hf_hub_download(HF_REPO, rel_path)
            shutil.copy2(src, dest)

    # Link weight file — text_encoder
    te_link = cache_dir / "text_encoder" / "model.safetensors"
    _force_link(te_path, te_link)
    # Drop the shard index (if present): the ComfyUI file is a single
    # non-sharded file, and a stale index would confuse the loader.
    shard_idx = cache_dir / "text_encoder" / "model.safetensors.index.json"
    if shard_idx.exists():
        shard_idx.unlink()

    # Link weight file — transformer
    tf_link = cache_dir / "transformer" / "diffusion_pytorch_model.safetensors"
    _force_link(tf_path, tf_link)

    # Link weight file — vae
    vae_link = cache_dir / "vae" / "diffusion_pytorch_model.safetensors"
    _force_link(vae_path, vae_link)

    print(f"目录结构已就绪: {cache_dir}")
    return str(cache_dir)
|
||
|
||
|
||
|
||
def _load_pipeline_comfyui(cfg: dict, device: str, torch_dtype: torch.dtype):
    """Load the pipeline from ComfyUI split files, component by component
    (safetensors only).

    Configs come from the HF repo cache; weights come from the local ComfyUI
    files. `device` is accepted for signature symmetry but placement is done
    by the caller (load_pipeline).
    """
    from diffusers import (
        AutoencoderKL,
        FlowMatchEulerDiscreteScheduler,
        ZImagePipeline,
        ZImageTransformer2DModel,
    )
    from transformers import AutoTokenizer, Qwen3Model

    comfyui = cfg["image"]["comfyui"]
    te_path = comfyui["text_encoder"]
    tf_path = comfyui["transformer"]
    vae_path = comfyui["vae"]

    print("模式: ComfyUI 拆分文件加载")
    print(f" Text Encoder : {te_path}")
    print(f" Transformer : {tf_path}")
    print(f" VAE : {vae_path}")

    # Scheduler and tokenizer are tiny — always pulled from the HF repo cache.
    print(" 加载 Scheduler & Tokenizer(配置来自 HuggingFace 缓存)...")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(HF_REPO, subfolder="scheduler")
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO, subfolder="tokenizer")

    print(" 加载 Transformer...")
    transformer = ZImageTransformer2DModel.from_single_file(
        tf_path,
        config=HF_REPO,
        subfolder="transformer",
        torch_dtype=torch_dtype,
    )

    print(" 加载 VAE...")
    vae = AutoencoderKL.from_single_file(
        vae_path,
        config=HF_REPO,
        subfolder="vae",
        torch_dtype=torch_dtype,
    )

    print(" 加载 Text Encoder (Qwen3 4B)...")
    te_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/config.json")
    te_gen_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/generation_config.json")

    # Stage configs plus a link to the weight file inside a temp dir created
    # NEXT TO the weights (same drive, so a hard link can succeed), then load
    # the encoder from that directory.
    te_parent_dir = str(Path(te_path).resolve().parent)
    with tempfile.TemporaryDirectory(dir=te_parent_dir) as tmpdir:
        shutil.copy2(te_config_path, os.path.join(tmpdir, "config.json"))
        shutil.copy2(te_gen_config_path, os.path.join(tmpdir, "generation_config.json"))
        _force_link(Path(te_path), Path(tmpdir) / "model.safetensors")
        text_encoder = Qwen3Model.from_pretrained(tmpdir, torch_dtype=torch_dtype)

    pipe = ZImagePipeline(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    return pipe
|
||
|
||
|
||
def hf_hub_download_cached(repo_id: str, filename: str) -> str:
    """Fetch one file from a HuggingFace repo, relying on the hub's local cache."""
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=filename)
|
||
|
||
|
||
def load_pipeline(cfg: dict):
    """Load the Z-Image-Turbo pipeline, auto-adapting to OpenVINO /
    HuggingFace / ComfyUI model formats.

    Side effects: records the chosen torch device in cfg["_resolved_device"]
    and whether OpenVINO is in use in cfg["_openvino_mode"]; these are read
    by generate_images().
    """
    img_cfg = cfg["image"]

    # OpenVINO mode: optimum.intel manages the device itself, so skip
    # resolve_device entirely.
    if _is_openvino_mode(cfg):
        ov_device = img_cfg["openvino"].get("device", "GPU")
        print(f"\n{'='*60}")
        print("正在加载 Z-Image-Turbo 模型 (OpenVINO)...")
        print(f"OpenVINO 设备: {ov_device}")
        print(f"{'='*60}\n")

        pipe = _load_pipeline_openvino(cfg)
        # Torch generators stay on CPU in OV mode.
        cfg["_resolved_device"] = "cpu"
        cfg["_openvino_mode"] = True
        return pipe

    from diffusers import ZImagePipeline

    device = resolve_device(img_cfg.get("device", "auto"))
    torch_dtype = get_supported_dtype(device, img_cfg.get("torch_dtype", "auto"))

    print(f"\n{'='*60}")
    print("正在加载 Z-Image-Turbo 模型...")
    print(f"推理设备: {device}")
    print(f"数据类型: {torch_dtype}")
    print(f"{'='*60}\n")

    if _is_comfyui_mode(cfg):
        pipe = _load_pipeline_comfyui(cfg, device, torch_dtype)
    else:
        model_id = img_cfg["model_id"]
        print(f"模式: HuggingFace 标准加载")
        print(f"模型路径: {model_id}")
        pipe = ZImagePipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=False,
        )

    # Normalize enable_cpu_offload: accepts bool-ish strings and mode names;
    # anything unrecognized means "no offload".
    raw_offload = str(img_cfg.get("enable_cpu_offload", "false")).strip().lower()
    offload_mode = {
        "false": None, "0": None, "no": None, "off": None,
        "true": "model", "1": "model", "yes": "model", "on": "model",
        "model": "model",
        "sequential": "sequential",
    }.get(raw_offload, None)

    if offload_mode and device in ("cuda", "xpu", "mps"):
        if offload_mode == "sequential":
            print("启用 Sequential CPU Offload: 逐层搬入显卡,最省显存但较慢")
            pipe.enable_sequential_cpu_offload(device=device)
        else:
            print("启用 Model CPU Offload: 组件级按需加载到显卡")
            pipe.enable_model_cpu_offload(device=device)
    else:
        if device != "cpu" and not offload_mode:
            print("提示: 所有模型将同时加载到显卡,如显存不足请在配置中开启 enable_cpu_offload")
        pipe.to(device)

    # Optional attention backend override; not applied on xpu/cpu.
    if device not in ("xpu", "cpu"):
        attn_backend = img_cfg.get("attention_backend", "sdpa")
        if attn_backend == "flash":
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == "flash_3":
            pipe.transformer.set_attention_backend("_flash_3")

    # Optional LoRA: load, then fuse at the configured weight.
    lora_cfg = cfg.get("lora", {})
    if lora_cfg.get("enabled") and lora_cfg.get("path"):
        lora_path = lora_cfg["path"]
        lora_weight = lora_cfg.get("weight", 0.8)
        print(f"正在加载 LoRA: {lora_path} (权重: {lora_weight})")
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=lora_weight)
        print("LoRA 加载完成")

    cfg["_resolved_device"] = device
    cfg["_openvino_mode"] = False
    return pipe
|
||
|
||
|
||
def generate_images(pipe, analysis: dict, cfg: dict) -> list[Path]:
    """Generate one image (or several seed variants) per analyzed scene.

    Returns the list of saved PNG paths. Also writes one prompt sidecar
    .txt per scene when output.save_prompts is enabled (default True).
    """
    img_cfg = cfg["image"]
    out_cfg = cfg["output"]
    lora_cfg = cfg.get("lora", {})
    # Set by load_pipeline(); defaults to cpu when absent.
    device = cfg.get("_resolved_device", "cpu")

    output_dir = Path(out_cfg.get("dir", "./output"))
    output_dir.mkdir(parents=True, exist_ok=True)

    prefix = out_cfg.get("filename_prefix", "poem")
    width, height = resolve_image_size(img_cfg)
    steps = img_cfg.get("num_inference_steps", 9)
    guidance = img_cfg.get("guidance_scale", 0.0)
    seed = img_cfg.get("seed", -1)  # negative = time-based random per run

    # LoRA trigger words get prepended to every prompt when configured.
    trigger_words = ""
    if lora_cfg.get("enabled") and lora_cfg.get("trigger_words"):
        trigger_words = lora_cfg["trigger_words"].strip()

    preset = img_cfg.get("size_preset", "custom")
    prompt_lang = img_cfg.get("prompt_language", "zh")
    # Clamp the variant count to 1..10.
    images_per_prompt = max(1, min(10, img_cfg.get("images_per_prompt", 1)))
    print(f"图片尺寸: {width}×{height}" + (f" (预设: {preset})" if preset != "custom" else ""))
    print(f"Prompt 语言: {prompt_lang}")
    if images_per_prompt > 1:
        print(f"每个 prompt 生成 {images_per_prompt} 张图(不同种子)")

    saved_paths = []
    total = len(analysis["images"])

    for i, img_info in enumerate(analysis["images"], 1):
        # Use the English prompt only when configured AND the LLM provided one.
        if prompt_lang == "en" and img_info.get("prompt_en"):
            prompt = img_info["prompt_en"]
        else:
            prompt = img_info["prompt"]
        if trigger_words:
            prompt = f"{trigger_words}, {prompt}"

        print(f"\n[{i}/{total}] 正在生成: {img_info['scene']}")
        print(f" 画风: {img_info.get('style', '未指定')}")
        print(f" Prompt({prompt_lang}): {prompt[:120]}...")

        for j in range(images_per_prompt):
            # Distinct, reproducible seed per (scene, variant) when a fixed
            # seed is configured; otherwise time-based.
            variant_offset = i * 100 + j
            actual_seed = (seed + variant_offset) if seed >= 0 else (int(time.time() * 1000) % (2**32) + variant_offset)
            generator = create_generator(device, actual_seed)

            # Variant suffix "a", "b", ... only when generating multiples.
            suffix = chr(ord("a") + j) if images_per_prompt > 1 else ""
            if images_per_prompt > 1:
                print(f" --- 第 {j+1}/{images_per_prompt} 张 (seed={actual_seed}) ---")

            start_time = time.time()

            result = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=guidance,
                generator=generator,
            )
            image: Image.Image = result.images[0]

            elapsed = time.time() - start_time
            print(f" 生成完成,耗时 {elapsed:.1f}s")

            img_path = output_dir / f"{prefix}_{i:02d}{suffix}.png"
            image.save(img_path)
            saved_paths.append(img_path)
            print(f" 已保存: {img_path}")

        # One prompt sidecar per scene (filename carries no variant suffix).
        if out_cfg.get("save_prompts", True):
            txt_path = output_dir / f"{prefix}_{i:02d}_prompt.txt"
            prompt_zh = img_info["prompt"]
            prompt_en = img_info.get("prompt_en", "")
            txt_path.write_text(
                f"Scene: {img_info['scene']}\n"
                f"Style: {img_info.get('style', '')}\n"
                f"Description: {img_info['description']}\n"
                f"Prompt(zh): {prompt_zh}\n"
                f"Prompt(en): {prompt_en}\n"
                f"Used({prompt_lang}): {prompt}\n",
                encoding="utf-8",
            )

    return saved_paths
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 主流程
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
    """CLI entry point: parse args, analyze the poem, then generate images."""
    parser = argparse.ArgumentParser(
        description="古诗词意境图生成器 — 基于 LLM 分析 + Z-Image-Turbo 生成"
    )
    parser.add_argument(
        "-c", "--config",
        default="config.yaml",
        help="配置文件路径(默认: config.yaml)",
    )
    parser.add_argument(
        "-p", "--poem",
        type=str,
        default=None,
        help="直接传入古诗词文本(如不指定则交互式输入)",
    )
    parser.add_argument(
        "--analyze-only",
        action="store_true",
        help="仅进行 LLM 分析,不生成图片",
    )
    parser.add_argument(
        "-o", "--output",
        type=str,
        default=None,
        help="覆盖输出目录",
    )
    args = parser.parse_args()

    cfg = load_config(args.config)

    # Output dir: explicit -o wins; otherwise nest date/time subdirs so
    # repeated runs never overwrite each other.
    if args.output:
        cfg["output"]["dir"] = args.output
    else:
        now = datetime.now()
        date_dir = now.strftime("%Y-%m-%d")
        time_dir = now.strftime("%H-%M-%S")
        cfg["output"]["dir"] = str(Path(cfg["output"].get("dir", "./output")) / date_dir / time_dir)

    # Poem text: from -p, or interactively (a blank line ends input).
    if args.poem:
        poem = args.poem
    else:
        print("请输入古诗词(输入空行结束):")
        lines = []
        while True:
            line = input()
            if line.strip() == "":
                break
            lines.append(line)
        poem = "\n".join(lines)

    if not poem.strip():
        print("未输入任何内容,退出。")
        sys.exit(0)

    print(f"\n📝 输入的诗词:\n{poem}")

    analysis = analyze_poetry(poem, cfg)
    display_analysis(analysis)

    # Persist the analysis next to the images for reproducibility.
    output_dir = Path(cfg["output"].get("dir", "./output"))
    output_dir.mkdir(parents=True, exist_ok=True)
    analysis_path = output_dir / "analysis.json"
    analysis_path.write_text(
        json.dumps(analysis, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(f"分析结果已保存: {analysis_path}")

    if args.analyze_only:
        print("\n已完成分析(--analyze-only 模式),跳过图片生成。")
        return

    # Report the inference backend before the (slow) model load.
    if _is_openvino_mode(cfg):
        ov_device = cfg["image"]["openvino"].get("device", "GPU")
        print(f"\n🖥 推理模式: OpenVINO ({ov_device})")
        if ov_device.upper() == "GPU" and hasattr(torch, "xpu") and torch.xpu.is_available():
            print(f" Intel XPU: {torch.xpu.get_device_name(0)}")
    else:
        device = resolve_device(cfg["image"].get("device", "auto"))
        print(f"\n🖥 推理设备: {device}")
        if device == "xpu":
            print(f" Intel XPU: {torch.xpu.get_device_name(0)}")
            print(f" 显存: {torch.xpu.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        elif device == "cuda":
            print(f" CUDA GPU: {torch.cuda.get_device_name(0)}")
            # BUG FIX: the property is `total_memory`, not `total_mem` — the
            # old code raised AttributeError on every CUDA machine.
            print(f" 显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    pipe = load_pipeline(cfg)
    saved = generate_images(pipe, analysis, cfg)

    print(f"\n{'='*60}")
    print(f"全部完成!共生成 {len(saved)} 幅图片:")
    for p in saved:
        print(f" 📁 {p}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    main()
|