# NOTE(review): removed a code-viewer header artifact ("843 lines / 32 KiB / Python",
# duplicated) that was pasted above the module docstring and is not valid Python.
"""
古诗词意境图生成器

将中国古典诗词通过 LLM 分析拆解为多个意境画面,
再使用 Z-Image-Turbo 本地模型逐一生成高质量图片。
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import yaml
|
||
from openai import OpenAI
|
||
from PIL import Image
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 设备检测与适配
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _init_xpu():
|
||
"""尝试初始化 Intel XPU 支持(需要 intel-extension-for-pytorch)。"""
|
||
try:
|
||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||
return True
|
||
except ImportError:
|
||
return False
|
||
|
||
|
||
def resolve_device(configured_device: str) -> str:
    """Decide the actual inference device from config plus hardware probing.

    An explicit user setting always wins; "auto" probes in the order
    cuda > xpu > mps > cpu.
    """
    if configured_device == "auto":
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            _init_xpu()
            return "xpu"
        mps_ready = (
            hasattr(torch, "backends")
            and hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
        )
        return "mps" if mps_ready else "cpu"

    if configured_device == "xpu":
        # Configured XPU but nothing detected yet: try loading IPEX, and
        # abort with installation hints when that still does not surface one.
        if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
            print("警告: 配置了 xpu 设备但未检测到 Intel XPU,尝试初始化 IPEX...")
            if not _init_xpu():
                print("错误: 无法加载 intel-extension-for-pytorch,请确认已安装。")
                print("安装命令: pip install intel-extension-for-pytorch")
                sys.exit(1)
            if not torch.xpu.is_available():
                print("错误: IPEX 已加载但仍未检测到 XPU 设备。")
                sys.exit(1)
    else:
        # Non-XPU choices still do a best-effort IPEX import (a harmless
        # no-op when the package is absent), matching prior behavior.
        _init_xpu()

    return configured_device
|
||
|
||
|
||
def get_supported_dtype(device: str, configured_dtype: str) -> torch.dtype:
    """Map the configured dtype name to a torch.dtype suited to the device.

    Intel Arc GPUs have incomplete bfloat16 operator coverage, so float16
    is preferred (and silently substituted) on "xpu".
    """
    if configured_dtype == "auto":
        auto_by_device = {
            "xpu": torch.float16,
            "cuda": torch.bfloat16,
            "mps": torch.bfloat16,
        }
        return auto_by_device.get(device, torch.float32)

    named = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    chosen = named.get(configured_dtype, torch.bfloat16)

    # Downgrade bfloat16 on XPU for operator compatibility.
    if device == "xpu" and chosen is torch.bfloat16:
        print("提示: Intel Arc GPU 上 bfloat16 部分算子兼容性不佳,自动切换为 float16")
        return torch.float16

    return chosen
|
||
|
||
|
||
def create_generator(device: str, seed: int) -> torch.Generator:
    """Return a seeded RNG on the given accelerator (CPU generator otherwise)."""
    gen = torch.Generator(device) if device in ("xpu", "cuda") else torch.Generator()
    return gen.manual_seed(seed)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 配置加载
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def load_config(config_path: str = "config.yaml") -> dict:
    """Load the YAML config; the LLM_API_KEY env var overrides llm.api_key."""
    with open(config_path, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    cfg["llm"]["api_key"] = os.environ.get("LLM_API_KEY") or cfg["llm"].get("api_key", "")
    return cfg
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM 古诗词分析
|
||
# ---------------------------------------------------------------------------
|
||
|
||
SYSTEM_PROMPT = """\
|
||
# Role
|
||
|
||
你是一位精通中国古典诗词和 AI 绘画提示词(Prompt)构建的艺术总监。你的任务是将用户输入的古诗词转化为 **Z-Image-Turbo** 能够完美理解的画面描述词。
|
||
|
||
# 解析与构建规则(须在内部遵循,不得单独输出分析过程)
|
||
|
||
1. **【意象提取】**:精准提取诗词中的核心意象、季节、时间、天气、动植物和人物状态。
|
||
2. **【风格指定】**:默认采用高质量的东方古典美学风格;根据诗词意境在下列类别中**为每个分镜的两种方案**各选一种**不同**的取向(两方案不得同质化):\
|
||
中国传统水墨画、工笔画、青绿山水、宋代院体画、敦煌壁画、新中式 3D/CG 唯美插画等。
|
||
3. **【画面构建】**:每条 `prompt` / `prompt_en` 必须是一段连续自然语言,包含主体、背景、构图方式(如远景、特写、留白构图)、光影(如清晨丁达尔光、皎洁月光、薄暮冷光)和色彩倾向;**禁止使用方括号 [] 或标签式分段**。
|
||
4. **【画质增强】**:在每条提示词结尾加上提升画质的关键词(如:8k 分辨率,极致细节,电影级光影,绝美意境,大师之作,壁纸级别);英文 `prompt_en` 结尾使用对等的英文质量词(如 masterpiece, 8k resolution, highly detailed, cinematic lighting)。
|
||
5. **【分镜与 JSON】**:若全诗时空统一则单幅;若有视角切换、时间推移或场景跳跃则拆为 2 至 4 幅分镜。\
|
||
**对外仅输出下方 JSON**,JSON 外不得有任何字符;JSON 内的 `prompt` 与 `prompt_en` 仅含可直接用于生图的纯净文本,不得包含分析、说明性前缀或多余标点。\
|
||
`analysis` 用一两句话概括分镜依据与整体意境即可;`description` 为对应分镜的简要画面说明(供归档,同样写入 JSON)。
|
||
|
||
# Output Format(仅输出此 JSON,无其他内容)
|
||
|
||
```json
|
||
{
|
||
"title": "诗词标题",
|
||
"author": "作者",
|
||
"dynasty": "朝代",
|
||
"genre": "体裁",
|
||
"analysis": "一两句话:分镜依据与整体意境。",
|
||
"images": [
|
||
{
|
||
"scene": "对应诗句原文",
|
||
"description": "该分镜的简要画面说明。",
|
||
"variants": [
|
||
{
|
||
"style": "本方案画风名称(与另一方案不同类)",
|
||
"style_rationale": "一句话说明为何该风格贴此分镜",
|
||
"prompt": "中文:风格句 + 完整画面描述 + 画质词;一段连续文字,无方括号。",
|
||
"prompt_en": "English: same scene intent, style keywords + full scene + quality tags; one fluent paragraph."
|
||
},
|
||
{
|
||
"style": "第二套画风名称",
|
||
"style_rationale": "一句话",
|
||
"prompt": "同上",
|
||
"prompt_en": "同上"
|
||
}
|
||
]
|
||
}
|
||
]
|
||
}
|
||
```
|
||
|
||
**示例(风格与篇幅仅供参考,勿照抄)** — 输入:大漠孤烟直,长河落日圆。\
|
||
输出中的 `prompt` 应类似:中国传统青绿山水与新中式 CG 结合风格。一望无际的浩瀚大漠,壮阔的黄昏景象。一条笔直的孤烟直冲云霄,远处的黄河蜿蜒流淌,波光粼粼,一轮巨大的、红彤彤的落日悬挂在长河尽头。画面以暖黄色和橘红色为主色调,充满苍凉与壮美的边塞意境,大景深,广角全景构图,电影级光影,8k 分辨率,极致细节,绝美壁纸级别。
|
||
|
||
**重要**:`images` 中每一项必须包含恰好 **2** 个 `variants`;`prompt` 与 `prompt_en` 为可直接用于 Z-Image-Turbo 的最终提示词正文。\
|
||
"""
|
||
|
||
|
||
def _build_user_message(poem: str, cfg: dict) -> str:
|
||
"""构造发送给 LLM 的用户消息,包含诗词和可选的风格期望。"""
|
||
style_pref = cfg["image"].get("style_preference", "").strip()
|
||
if style_pref:
|
||
style_line = f"【风格期望】:{style_pref}"
|
||
else:
|
||
style_line = "【风格期望】:默认(根据诗意自动选择最契合的传统画风)"
|
||
|
||
return (
|
||
f"请为以下古诗词生成图像提示词:\n\n"
|
||
f"【输入诗词】:\n{poem}\n\n"
|
||
f"{style_line}\n\n"
|
||
f"请严格遵循 System Prompt 中的五项规则(意象、风格、画面、画质、纯净 JSON),"
|
||
f"为单幅或多幅分镜各输出 **2 套**不同取向的画风方案(`variants` 数组),仅返回约定的 JSON。"
|
||
)
|
||
|
||
|
||
def analyze_poetry(poem: str, cfg: dict) -> dict:
    """Ask the LLM to analyze the poem and return its structured JSON plan."""
    llm_cfg = cfg["llm"]

    client = OpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        timeout=120,
    )

    style_pref = cfg["image"].get("style_preference", "").strip()
    banner = "=" * 60
    print(f"\n{banner}")
    print("正在调用 LLM 生成古诗词画面提示词(JSON)...")
    print(f"模型: {llm_cfg['model']}")
    if style_pref:
        print(f"风格期望: {style_pref}")
    print(f"{banner}\n")

    response = client.chat.completions.create(
        model=llm_cfg["model"],
        temperature=llm_cfg.get("temperature", 0.7),
        max_tokens=llm_cfg.get("max_tokens", 4096),
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": _build_user_message(poem, cfg)},
        ],
    )

    content = response.choices[0].message.content.strip()

    # The model may wrap its JSON in a markdown code fence; unwrap it first.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", content, re.DOTALL)
    if fenced:
        content = fenced.group(1).strip()

    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # Last resort: extract the outermost {...} span and parse that.
        braced = re.search(r"\{.*\}", content, re.DOTALL)
        if braced is None:
            print("LLM 返回内容无法解析为 JSON:")
            print(content)
            sys.exit(1)
        return json.loads(braced.group())
|
||
|
||
|
||
def _normalize_scene_variants(img_info: dict, max_variants: int) -> list[tuple[str, dict]]:
|
||
"""从单条分镜解析待生成的画风变体,供绘图循环使用。
|
||
|
||
返回 [(文件名标签如 v01, variant 字典), ...]。
|
||
兼容旧版 JSON(无 variants 数组时退回顶层 prompt / prompt_en)。
|
||
"""
|
||
max_variants = max(1, min(2, int(max_variants)))
|
||
raw = img_info.get("variants")
|
||
collected: list[dict] = []
|
||
if isinstance(raw, list):
|
||
for v in raw:
|
||
if isinstance(v, dict) and (v.get("prompt") or v.get("prompt_en")):
|
||
collected.append(v)
|
||
if collected:
|
||
return [(f"v{idx:02d}", collected[idx - 1]) for idx in range(1, min(len(collected), max_variants) + 1)]
|
||
|
||
legacy = {
|
||
"style": img_info.get("style", ""),
|
||
"style_rationale": "",
|
||
"prompt": img_info.get("prompt", ""),
|
||
"prompt_en": img_info.get("prompt_en", ""),
|
||
}
|
||
if legacy["prompt"] or legacy["prompt_en"]:
|
||
return [("v01", legacy)]
|
||
return []
|
||
|
||
|
||
def display_analysis(analysis: dict) -> None:
    """Print the LLM's analysis and per-scene prompt plan in a readable layout."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(
        f"📜 {analysis.get('title', '未知')} — {analysis.get('dynasty', '')}"
        f" · {analysis.get('author', '未知')} [{analysis.get('genre', '')}]"
    )
    print(rule)
    print(f"\n🔍 意境分析:{analysis.get('analysis', '')}\n")

    for idx, img in enumerate(analysis["images"], 1):
        print("─" * 50)
        print(f"🖼 第 {idx} 幅 | {img['scene']}")
        if img.get("description", ""):
            print(f"  中文描述:{img['description']}")

        variants = img.get("variants")
        if isinstance(variants, list) and variants:
            for vnum, var in enumerate(variants, 1):
                print(f"  ─ 画风方案 {vnum}:{var.get('style', '未指定')}")
                if var.get("style_rationale"):
                    print(f"    说明:{var['style_rationale']}")
                # Preview both prompt languages, truncated to 120 chars.
                for tag in ("zh", "en"):
                    text = (var.get("prompt") if tag == "zh" else var.get("prompt_en")) or ""
                    if text:
                        ellipsis = "..." if len(text) > 120 else ""
                        print(f"    Prompt({tag}):{text[:120]}{ellipsis}")
        else:
            # Legacy single-style shape without a variants array.
            print(f"  画风选择:{img.get('style', '未指定')}")
            zh_prompt = img.get("prompt") or ""
            if zh_prompt:
                shown = zh_prompt[:120] + "..." if len(zh_prompt) > 120 else zh_prompt
                print(f"  Prompt(zh):{shown}")
            en_prompt = img.get("prompt_en")
            if en_prompt:
                shown = en_prompt[:120] + "..." if len(en_prompt) > 120 else en_prompt
                print(f"  Prompt(en):{shown}")

    scene_count = len(analysis["images"])
    plan_count = 0
    for img in analysis["images"]:
        if isinstance(img.get("variants"), list):
            plan_count += len(img["variants"])
        elif img.get("prompt") or img.get("prompt_en"):
            plan_count += 1
    print(f"\n共 {scene_count} 个分镜;LLM 共给出约 {plan_count} 套画风方案(生成张数受配置 style_variants 与 images_per_prompt 影响)\n")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 尺寸预设
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Named (width, height) presets selectable via the image.size_preset config key;
# resolve_image_size() falls back to explicit width/height for "custom"/unknown.
SIZE_PRESETS: dict[str, tuple[int, int]] = {
    "square": (1024, 1024),
    "phone": ( 576, 1024),
    "phone_hd": ( 768, 1344),
    "desktop": (1024, 576),
    "desktop_hd": (1344, 768),
    "ultrawide": (1536, 640),
}
|
||
|
||
|
||
def resolve_image_size(img_cfg: dict) -> tuple[int, int]:
    """Return (width, height) from a named preset, else explicit dimensions."""
    preset = img_cfg.get("size_preset", "").strip().lower()
    # Empty, "custom", or an unknown preset name falls through to width/height.
    if preset not in ("", "custom") and preset in SIZE_PRESETS:
        return SIZE_PRESETS[preset]
    return img_cfg.get("width", 1024), img_cfg.get("height", 1024)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Z-Image-Turbo 本地图片生成
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# HuggingFace repository hosting the official Z-Image-Turbo model.
HF_REPO = "Tongyi-MAI/Z-Image-Turbo"

# Small config files fetched from the HuggingFace repo (network needed on
# first run; cached automatically afterwards).
_HF_CONFIG_FILES = [
    "model_index.json",
    "scheduler/scheduler_config.json",
    "tokenizer/merges.txt",
    "tokenizer/tokenizer_config.json",
    "tokenizer/vocab.json",
    "text_encoder/config.json",
    "text_encoder/generation_config.json",
    "transformer/config.json",
    "vae/config.json",
]
|
||
|
||
|
||
def _force_link(src: Path, dst: Path) -> None:
|
||
"""创建从 dst 指向 src 的链接,兼容 Windows 无管理员权限的场景。
|
||
|
||
优先级: 符号链接 → 硬链接 → 复制文件
|
||
- 符号链接在 Windows 下需要管理员权限或开启开发者模式
|
||
- 硬链接无需特权但要求 src 和 dst 在同一驱动器
|
||
- 以上均失败时回退到复制(大文件会较慢,但保证可用)
|
||
"""
|
||
src = Path(src).resolve()
|
||
dst = Path(dst)
|
||
if dst.exists() or dst.is_symlink():
|
||
dst.unlink()
|
||
|
||
# 1. 尝试符号链接
|
||
try:
|
||
dst.symlink_to(src)
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 2. 尝试硬链接(要求同一驱动器/文件系统)
|
||
try:
|
||
os.link(str(src), str(dst))
|
||
return
|
||
except OSError:
|
||
pass
|
||
|
||
# 3. 回退到复制
|
||
print(f" 提示: 无法创建链接,正在复制文件: {src.name}(可能较慢)")
|
||
shutil.copy2(str(src), str(dst))
|
||
|
||
|
||
def _is_comfyui_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 ComfyUI 拆分文件模式。"""
|
||
comfyui = cfg["image"].get("comfyui", {})
|
||
return bool(
|
||
comfyui.get("text_encoder")
|
||
and comfyui.get("transformer")
|
||
and comfyui.get("vae")
|
||
)
|
||
|
||
|
||
def _is_openvino_mode(cfg: dict) -> bool:
|
||
"""判断是否配置了 OpenVINO 推理模式。"""
|
||
return bool(cfg["image"].get("openvino", {}).get("model_path"))
|
||
|
||
|
||
def _load_pipeline_openvino(cfg: dict):
    """Load the Z-Image-Turbo pipeline through OpenVINO.

    Requires a pre-exported OpenVINO IR model, e.g.:
        optimum-cli export openvino --model Tongyi-MAI/Z-Image-Turbo \\
            --weight-format int8 z-image-turbo-ov
    """
    from optimum.intel import OVZImagePipeline

    ov_cfg = cfg["image"]["openvino"]
    model_path = ov_cfg["model_path"]
    ov_device = ov_cfg.get("device", "GPU")

    print(f"模式: OpenVINO 推理")
    print(f"  模型路径 : {model_path}")
    print(f"  OV 设备  : {ov_device}")

    # Exporting is a slow offline step; bail out with instructions instead
    # of triggering it implicitly.
    if not Path(model_path).exists():
        print(f"错误: OpenVINO 模型目录不存在: {model_path}")
        print("请先使用 optimum-cli 导出模型:")
        print(f"  optimum-cli export openvino --model {HF_REPO} --weight-format int8 {model_path}")
        sys.exit(1)

    return OVZImagePipeline.from_pretrained(model_path, device=ov_device)
|
||
|
||
|
||
def _build_hf_layout_from_comfyui(cfg: dict) -> str:
    """Assemble a HuggingFace-compatible model directory from ComfyUI files.

    Strategy: download the tiny JSON/txt config files from the HF repo
    (well under 100 KB total), then link the heavyweight ComfyUI weight
    files into the expected positions, yielding a directory that
    ``ZImagePipeline.from_pretrained()`` can load directly.
    """
    from huggingface_hub import hf_hub_download

    comfyui = cfg["image"]["comfyui"]
    weights = {
        "text_encoder": Path(comfyui["text_encoder"]).resolve(),
        "transformer": Path(comfyui["transformer"]).resolve(),
        "vae": Path(comfyui["vae"]).resolve(),
    }
    for name, path in weights.items():
        if not path.exists():
            print(f"错误: ComfyUI {name} 文件不存在: {path}")
            sys.exit(1)

    cache_dir = Path(".cache") / "comfyui_hf_layout"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("正在准备 HuggingFace 兼容目录结构(仅首次需下载配置文件)...")

    # Fetch each small config file once; later runs reuse the local copy.
    for rel_path in _HF_CONFIG_FILES:
        dest = cache_dir / rel_path
        if not dest.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(hf_hub_download(HF_REPO, rel_path), dest)

    # Link the text encoder weight into place.
    _force_link(weights["text_encoder"], cache_dir / "text_encoder" / "model.safetensors")
    # The ComfyUI file is a single non-sharded weight, so a leftover shard
    # index from the HF layout must not survive.
    shard_idx = cache_dir / "text_encoder" / "model.safetensors.index.json"
    if shard_idx.exists():
        shard_idx.unlink()

    # Link transformer and VAE weights.
    _force_link(weights["transformer"], cache_dir / "transformer" / "diffusion_pytorch_model.safetensors")
    _force_link(weights["vae"], cache_dir / "vae" / "diffusion_pytorch_model.safetensors")

    print(f"目录结构已就绪: {cache_dir}")
    return str(cache_dir)
|
||
|
||
|
||
|
||
def _load_pipeline_comfyui(cfg: dict, device: str, torch_dtype: torch.dtype):
    """Build a ZImagePipeline from ComfyUI split weight files, component by
    component (safetensors only).

    Scheduler/tokenizer configs come from the HF repo cache; transformer and
    VAE load from the single ComfyUI files; the Qwen3 text encoder is loaded
    through a temporary HF-style directory assembled next to its weight file.

    NOTE(review): ``device`` is accepted but not used here — the caller is
    expected to move the pipeline to the device afterwards; confirm.
    """
    from diffusers import (
        AutoencoderKL,
        FlowMatchEulerDiscreteScheduler,
        ZImagePipeline,
        ZImageTransformer2DModel,
    )
    from transformers import AutoTokenizer, Qwen3Model

    comfyui = cfg["image"]["comfyui"]
    te_path = comfyui["text_encoder"]
    tf_path = comfyui["transformer"]
    vae_path = comfyui["vae"]

    print("模式: ComfyUI 拆分文件加载")
    print(f"  Text Encoder : {te_path}")
    print(f"  Transformer  : {tf_path}")
    print(f"  VAE          : {vae_path}")

    # Scheduler and tokenizer carry no large weights; pull their configs
    # straight from the HuggingFace repo (cached after the first download).
    print("  加载 Scheduler & Tokenizer(配置来自 HuggingFace 缓存)...")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(HF_REPO, subfolder="scheduler")
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO, subfolder="tokenizer")

    # Single-file loading pairs the ComfyUI weight with the repo's config.
    print("  加载 Transformer...")
    transformer = ZImageTransformer2DModel.from_single_file(
        tf_path,
        config=HF_REPO,
        subfolder="transformer",
        torch_dtype=torch_dtype,
    )

    print("  加载 VAE...")
    vae = AutoencoderKL.from_single_file(
        vae_path,
        config=HF_REPO,
        subfolder="vae",
        torch_dtype=torch_dtype,
    )

    print("  加载 Text Encoder (Qwen3 4B)...")
    te_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/config.json")
    te_gen_config_path = hf_hub_download_cached(HF_REPO, "text_encoder/generation_config.json")

    # The temp dir lives beside the weight file — presumably so the
    # hard-link fallback in _force_link can stay on the same filesystem
    # (TODO confirm). The model is fully materialized by from_pretrained
    # before the directory is cleaned up on context exit.
    te_parent_dir = str(Path(te_path).resolve().parent)
    with tempfile.TemporaryDirectory(dir=te_parent_dir) as tmpdir:
        shutil.copy2(te_config_path, os.path.join(tmpdir, "config.json"))
        shutil.copy2(te_gen_config_path, os.path.join(tmpdir, "generation_config.json"))
        _force_link(Path(te_path), Path(tmpdir) / "model.safetensors")
        text_encoder = Qwen3Model.from_pretrained(tmpdir, torch_dtype=torch_dtype)

    pipe = ZImagePipeline(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    return pipe
|
||
|
||
|
||
def hf_hub_download_cached(repo_id: str, filename: str) -> str:
    """Fetch one file from a HuggingFace repo, relying on the hub's local cache."""
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id, filename)
|
||
|
||
|
||
def load_pipeline(cfg: dict):
    """Load the Z-Image-Turbo pipeline, auto-selecting among OpenVINO,
    HuggingFace standard, and ComfyUI split-file loading.

    Side effects: stores the chosen device under cfg["_resolved_device"] and
    the mode flag under cfg["_openvino_mode"] for generate_images() to read.
    """
    img_cfg = cfg["image"]

    # OpenVINO mode: optimum.intel manages the device itself, so the manual
    # resolve_device step is skipped entirely.
    if _is_openvino_mode(cfg):
        ov_device = img_cfg["openvino"].get("device", "GPU")
        print(f"\n{'='*60}")
        print("正在加载 Z-Image-Turbo 模型 (OpenVINO)...")
        print(f"OpenVINO 设备: {ov_device}")
        print(f"{'='*60}\n")

        pipe = _load_pipeline_openvino(cfg)
        # Torch-side generators still need a device string; use CPU.
        cfg["_resolved_device"] = "cpu"
        cfg["_openvino_mode"] = True
        return pipe

    from diffusers import ZImagePipeline

    device = resolve_device(img_cfg.get("device", "auto"))
    torch_dtype = get_supported_dtype(device, img_cfg.get("torch_dtype", "auto"))

    print(f"\n{'='*60}")
    print("正在加载 Z-Image-Turbo 模型...")
    print(f"推理设备: {device}")
    print(f"数据类型: {torch_dtype}")
    print(f"{'='*60}\n")

    if _is_comfyui_mode(cfg):
        pipe = _load_pipeline_comfyui(cfg, device, torch_dtype)
    else:
        model_id = img_cfg["model_id"]
        print(f"模式: HuggingFace 标准加载")
        print(f"模型路径: {model_id}")
        pipe = ZImagePipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=False,
        )

    # enable_cpu_offload accepts booleans and strings; normalize to one of
    # None (disabled), "model", or "sequential".
    raw_offload = str(img_cfg.get("enable_cpu_offload", "false")).strip().lower()
    offload_mode = {
        "false": None, "0": None, "no": None, "off": None,
        "true": "model", "1": "model", "yes": "model", "on": "model",
        "model": "model",
        "sequential": "sequential",
    }.get(raw_offload, None)

    if offload_mode and device in ("cuda", "xpu", "mps"):
        if offload_mode == "sequential":
            print("启用 Sequential CPU Offload: 逐层搬入显卡,最省显存但较慢")
            pipe.enable_sequential_cpu_offload(device=device)
        else:
            print("启用 Model CPU Offload: 组件级按需加载到显卡")
            pipe.enable_model_cpu_offload(device=device)
    else:
        if device != "cpu" and not offload_mode:
            print("提示: 所有模型将同时加载到显卡,如显存不足请在配置中开启 enable_cpu_offload")
        pipe.to(device)

    # Optional attention backend override; skipped on xpu/cpu.
    if device not in ("xpu", "cpu"):
        attn_backend = img_cfg.get("attention_backend", "sdpa")
        if attn_backend == "flash":
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == "flash_3":
            pipe.transformer.set_attention_backend("_flash_3")

    # Optional LoRA: load and fuse at the configured weight.
    lora_cfg = cfg.get("lora", {})
    if lora_cfg.get("enabled") and lora_cfg.get("path"):
        lora_path = lora_cfg["path"]
        lora_weight = lora_cfg.get("weight", 0.8)
        print(f"正在加载 LoRA: {lora_path} (权重: {lora_weight})")
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=lora_weight)
        print("LoRA 加载完成")

    cfg["_resolved_device"] = device
    cfg["_openvino_mode"] = False
    return pipe
|
||
|
||
|
||
def generate_images(pipe, analysis: dict, cfg: dict) -> list[Path]:
    """Render every scene/variant in the LLM plan and return the saved paths.

    For each scene in analysis["images"], resolves up to ``style_variants``
    style variants, renders ``images_per_prompt`` images per variant with
    distinct seeds, saves PNGs named ``{prefix}_{scene:02d}_{vNN}[a|b|...].png``,
    and (optionally) archives the prompts to a per-scene .txt file.
    """
    img_cfg = cfg["image"]
    out_cfg = cfg["output"]
    lora_cfg = cfg.get("lora", {})
    # Set by load_pipeline(); "cpu" when unset.
    device = cfg.get("_resolved_device", "cpu")

    output_dir = Path(out_cfg.get("dir", "./output"))
    output_dir.mkdir(parents=True, exist_ok=True)

    prefix = out_cfg.get("filename_prefix", "poem")
    width, height = resolve_image_size(img_cfg)
    steps = img_cfg.get("num_inference_steps", 9)
    guidance = img_cfg.get("guidance_scale", 0.0)
    seed = img_cfg.get("seed", -1)  # negative = randomize from wall clock

    # LoRA trigger words, if any, are prepended to every prompt.
    trigger_words = ""
    if lora_cfg.get("enabled") and lora_cfg.get("trigger_words"):
        trigger_words = lora_cfg["trigger_words"].strip()

    preset = img_cfg.get("size_preset", "custom")
    prompt_lang = img_cfg.get("prompt_language", "zh")
    # Clamp the per-prompt repeat count and variant count to sane ranges.
    images_per_prompt = max(1, min(10, img_cfg.get("images_per_prompt", 1)))
    max_style_variants = max(1, min(2, int(img_cfg.get("style_variants", 2))))
    print(f"图片尺寸: {width}×{height}" + (f" (预设: {preset})" if preset != "custom" else ""))
    print(f"Prompt 语言: {prompt_lang}")
    print(f"每分镜画风方案数: {max_style_variants}(配置项 style_variants,1 或 2)")
    if images_per_prompt > 1:
        print(f"每个 prompt 生成 {images_per_prompt} 张图(不同种子)")

    saved_paths: list[Path] = []
    total = len(analysis["images"])

    for i, img_info in enumerate(analysis["images"], 1):
        variant_list = _normalize_scene_variants(img_info, max_style_variants)
        if not variant_list:
            print(f"\n警告: 第 {i}/{total} 幅分镜无有效 prompt,已跳过: {img_info.get('scene', '')}")
            continue

        print(f"\n[{i}/{total}] 分镜: {img_info['scene']}")

        # Archive header for this scene's prompt .txt (variants appended below).
        if out_cfg.get("save_prompts", True):
            txt_lines = [
                f"Scene: {img_info['scene']}\n",
                f"Description: {img_info.get('description', '')}\n",
                f"Prompt_language_used: {prompt_lang}\n",
            ]

        for vi, (v_label, variant) in enumerate(variant_list):
            # Prompt language selection: prefer English only when configured
            # AND available; otherwise fall back zh -> en -> "".
            if prompt_lang == "en" and variant.get("prompt_en"):
                prompt = variant["prompt_en"]
            else:
                prompt = variant.get("prompt") or variant.get("prompt_en") or ""
            if trigger_words:
                prompt = f"{trigger_words}, {prompt}"

            st = variant.get("style", "未指定")
            print(f"  [{v_label}] 画风: {st}")
            prev = prompt[:120] + ("..." if len(prompt) > 120 else "")
            print(f"  Prompt({prompt_lang}): {prev}")

            if out_cfg.get("save_prompts", True):
                txt_lines.append(f"\n--- {v_label} | Style: {st} ---\n")
                if variant.get("style_rationale"):
                    txt_lines.append(f"Rationale: {variant['style_rationale']}\n")
                txt_lines.append(f"Prompt(zh): {variant.get('prompt', '')}\n")
                txt_lines.append(f"Prompt(en): {variant.get('prompt_en', '')}\n")
                txt_lines.append(f"Used({prompt_lang}): {prompt}\n")

            for j in range(images_per_prompt):
                # Seed scheme: spread scene (i*100), variant (vi*17) and
                # repeat (j) apart so every render gets a distinct seed.
                variant_offset = i * 100 + vi * 17 + j
                actual_seed = (seed + variant_offset) if seed >= 0 else (int(time.time() * 1000) % (2**32) + variant_offset)
                generator = create_generator(device, actual_seed)

                # Repeats get an "a"/"b"/... filename suffix.
                seed_suffix = chr(ord("a") + j) if images_per_prompt > 1 else ""
                if images_per_prompt > 1:
                    print(f"  --- 同画风第 {j+1}/{images_per_prompt} 张 (seed={actual_seed}) ---")

                start_time = time.time()

                result = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=generator,
                )
                image: Image.Image = result.images[0]

                elapsed = time.time() - start_time
                print(f"  生成完成,耗时 {elapsed:.1f}s")

                img_path = output_dir / f"{prefix}_{i:02d}_{v_label}{seed_suffix}.png"
                image.save(img_path)
                saved_paths.append(img_path)
                print(f"  已保存: {img_path}")

        if out_cfg.get("save_prompts", True):
            txt_path = output_dir / f"{prefix}_{i:02d}_prompt.txt"
            txt_path.write_text("".join(txt_lines), encoding="utf-8")

    return saved_paths
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 主流程
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
    """CLI entry point: parse args, analyze the poem via the LLM, render images.

    Flow: load config -> resolve output dir -> obtain poem (flag or stdin) ->
    LLM analysis (archived as analysis.json) -> optionally stop
    (--analyze-only) -> load pipeline -> generate and report images.
    """
    parser = argparse.ArgumentParser(
        description="古诗词意境图生成器 — 基于 LLM 分析 + Z-Image-Turbo 生成"
    )
    parser.add_argument(
        "-c", "--config",
        default="config.yaml",
        help="配置文件路径(默认: config.yaml)",
    )
    parser.add_argument(
        "-p", "--poem",
        type=str,
        default=None,
        help="直接传入古诗词文本(如不指定则交互式输入)",
    )
    parser.add_argument(
        "--analyze-only",
        action="store_true",
        help="仅进行 LLM 分析,不生成图片",
    )
    parser.add_argument(
        "-o", "--output",
        type=str,
        default=None,
        help="覆盖输出目录",
    )
    args = parser.parse_args()

    cfg = load_config(args.config)

    # Without -o, images go into <configured dir>/YYYY-MM-DD/HH-MM-SS/.
    if args.output:
        cfg["output"]["dir"] = args.output
    else:
        now = datetime.now()
        date_dir = now.strftime("%Y-%m-%d")
        time_dir = now.strftime("%H-%M-%S")
        cfg["output"]["dir"] = str(Path(cfg["output"].get("dir", "./output")) / date_dir / time_dir)

    # Poem source: CLI flag, else interactive lines terminated by a blank line.
    if args.poem:
        poem = args.poem
    else:
        print("请输入古诗词(输入空行结束):")
        lines = []
        while True:
            line = input()
            if line.strip() == "":
                break
            lines.append(line)
        poem = "\n".join(lines)

    if not poem.strip():
        print("未输入任何内容,退出。")
        sys.exit(0)

    print(f"\n📝 输入的诗词:\n{poem}")

    analysis = analyze_poetry(poem, cfg)
    display_analysis(analysis)

    # Archive the raw LLM plan next to the images for reproducibility.
    output_dir = Path(cfg["output"].get("dir", "./output"))
    output_dir.mkdir(parents=True, exist_ok=True)
    analysis_path = output_dir / "analysis.json"
    analysis_path.write_text(
        json.dumps(analysis, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(f"分析结果已保存: {analysis_path}")

    if args.analyze_only:
        print("\n已完成分析(--analyze-only 模式),跳过图片生成。")
        return

    # Informational device banner, printed before the (slow) pipeline load.
    if _is_openvino_mode(cfg):
        ov_device = cfg["image"]["openvino"].get("device", "GPU")
        print(f"\n🖥 推理模式: OpenVINO ({ov_device})")
        if ov_device.upper() == "GPU" and hasattr(torch, "xpu") and torch.xpu.is_available():
            print(f"  Intel XPU: {torch.xpu.get_device_name(0)}")
    else:
        device = resolve_device(cfg["image"].get("device", "auto"))
        print(f"\n🖥 推理设备: {device}")
        if device == "xpu":
            print(f"  Intel XPU: {torch.xpu.get_device_name(0)}")
            print(f"  显存: {torch.xpu.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        elif device == "cuda":
            print(f"  CUDA GPU: {torch.cuda.get_device_name(0)}")
            # Fix: the attribute is `total_memory`; the previous `total_mem`
            # raised AttributeError whenever a CUDA device was selected.
            print(f"  显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    pipe = load_pipeline(cfg)
    saved = generate_images(pipe, analysis, cfg)

    print(f"\n{'='*60}")
    print(f"全部完成!共生成 {len(saved)} 幅图片:")
    for p in saved:
        print(f"  📁 {p}")
    print(f"{'='*60}\n")
|
||
|
||
|
||
# Script entry point guard: only run the CLI when executed directly.
if __name__ == "__main__":
    main()
|