Files
ZImageTurbo/prompt_to_image.py
2026-03-30 23:00:06 +08:00

205 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
直接使用用户输入的 prompt 调用本地 Z-Image-Turbo 出图。
复用 config.yaml 中的 image / lora / output 等推理配置;
不对 prompt 做 LLM 改写,不自动拼接 LoRA 触发词(触发词请自行写进 prompt)。
"""
from __future__ import annotations
import argparse
import sys
import time
from datetime import datetime
from pathlib import Path
from poetry_to_image import (
create_generator,
load_config,
load_pipeline,
resolve_image_size,
)
def _read_prompt_from_file(path: Path) -> str:
    """Read a prompt file as UTF-8 text without altering its content.

    The bytes are decoded directly, so no whitespace stripping and no newline
    translation takes place. A leading UTF-8 byte-order mark (written by some
    Windows editors) is an encoding signature rather than prompt text, so it
    is dropped instead of being passed on to the model as U+FEFF.

    Args:
        path: File to read.

    Returns:
        The decoded file content.

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the content is not valid UTF-8.
    """
    # "utf-8-sig" decodes exactly like "utf-8" but skips one leading BOM.
    return path.read_bytes().decode("utf-8-sig")
def _collect_prompts(args: argparse.Namespace) -> list[str]:
    """Gather the prompts to render from the available input sources.

    Sources are consulted in this order:

    1. ``args.prompts`` followed by the content of every path in
       ``args.prompt_files``.
    2. If that yields nothing and stdin is not a terminal, the whole of
       stdin decoded as UTF-8 becomes a single prompt.
    3. If that yields nothing and stdin is a terminal, lines are read
       interactively until an empty line or end-of-file and joined with
       newlines into a single prompt.

    Args:
        args: Parsed command-line namespace; the ``prompts`` and
            ``prompt_files`` attributes are read.

    Returns:
        The collected prompts; may be empty.

    Raises:
        SystemExit: With status 1 when a listed prompt file does not exist.
    """
    collected: list[str] = list(args.prompts) if args.prompts else []

    for raw_path in args.prompt_files or []:
        prompt_file = Path(raw_path)
        if not prompt_file.is_file():
            print(f"错误: 文件不存在: {prompt_file}", file=sys.stderr)
            sys.exit(1)
        collected.append(_read_prompt_from_file(prompt_file))

    if collected:
        return collected

    if not sys.stdin.isatty():
        # Piped input: read raw bytes so the text is decoded exactly once.
        collected.append(sys.stdin.buffer.read().decode("utf-8"))
        return collected

    print("请输入 prompt空行结束")
    typed: list[str] = []
    try:
        # iter() with a sentinel stops at the first empty line.
        for entered in iter(input, ""):
            typed.append(entered)
    except EOFError:
        # End-of-file finishes the input just like an empty line does.
        pass

    joined = "\n".join(typed)
    if joined:
        collected.append(joined)
    return collected
def _generate(
    pipe,
    prompt: str,
    *,
    cfg: dict,
    index: int,
    output_dir: Path,
    filename_prefix: str,
) -> list[Path]:
    """Render one prompt and save the resulting image(s).

    The prompt string is handed to ``pipe`` unchanged.

    Args:
        pipe: Callable pipeline; invoked with ``prompt``, ``height``,
            ``width``, ``num_inference_steps``, ``guidance_scale`` and
            ``generator`` keyword arguments. Its result must expose
            ``images``.
        prompt: Text passed to the pipeline and, when enabled, written next
            to each image.
        cfg: Configuration mapping; the ``"image"`` and ``"output"`` sections
            and the optional ``"_resolved_device"`` key are read.
        index: Position of this prompt in the run; used in file names and to
            offset the seed.
        output_dir: Directory that receives the files (not created here).
        filename_prefix: Leading part of every generated file name.

    Returns:
        Paths of the saved PNG files, in generation order.
    """
    image_cfg = cfg["image"]
    output_cfg = cfg["output"]
    device = cfg.get("_resolved_device", "cpu")
    width, height = resolve_image_size(image_cfg)
    steps = image_cfg.get("num_inference_steps", 9)
    guidance = image_cfg.get("guidance_scale", 0.0)
    seed = image_cfg.get("seed", -1)
    # The number of variants per prompt is clamped to the range 1..10.
    count = max(1, min(10, image_cfg.get("images_per_prompt", 1)))
    several = count > 1

    written: list[Path] = []
    for variant in range(count):
        offset = index * 100 + variant
        # A negative seed means "pick one": use the wall clock in
        # milliseconds, folded into 32 bits.
        base_seed = seed if seed >= 0 else int(time.time() * 1000) % (2**32)
        actual_seed = base_seed + offset
        generator = create_generator(device, actual_seed)

        if several:
            # Variants of the same prompt are told apart by a letter suffix.
            suffix = chr(ord("a") + variant)
            print(f" --- 第 {variant + 1}/{count} 张 (seed={actual_seed}) ---")
        else:
            suffix = ""

        started = time.time()
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
        ).images[0]
        elapsed = time.time() - started
        print(f" 生成完成,耗时 {elapsed:.1f}s")

        stem = f"{filename_prefix}_{index:02d}{suffix}"
        img_path = output_dir / f"{stem}.png"
        image.save(img_path)
        written.append(img_path)
        print(f" 已保存: {img_path}")

        if output_cfg.get("save_prompts", True):
            # Written as bytes so the prompt text is stored without any
            # newline translation.
            prompt_path = output_dir / f"{stem}_prompt.txt"
            prompt_path.write_bytes(prompt.encode("utf-8"))
    return written
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="Z-Image-Turbo 直出图:使用用户给定 prompt不做文本侧处理"
    )
    add = parser.add_argument
    add(
        "-c", "--config",
        default="config.yaml",
        help="配置文件路径(默认: config.yaml",
    )
    add(
        "-p", "--prompt",
        action="append",
        dest="prompts",
        metavar="TEXT",
        help="prompt 文本;可多次指定以连续生成多张不同 prompt",
    )
    add(
        "-f", "--file",
        action="append",
        dest="prompt_files",
        metavar="PATH",
        help="从 UTF-8 文件读取整段 prompt可多次指定",
    )
    add(
        "-o", "--output",
        default=None,
        help="输出目录(默认: output 下按日期时间分子目录,与 poetry_to_image 一致)",
    )
    add(
        "--flat-output",
        action="store_true",
        help="将输出直接写入配置中的 output.dir不再追加 日期/时间 子目录",
    )
    return parser


def main() -> None:
    """Command-line entry point: render every supplied prompt to image files.

    Steps, in order: parse the arguments, load the configuration, decide the
    output directory (``--output`` wins; otherwise a date/time subdirectory
    is appended unless ``--flat-output`` is given), collect the prompts,
    drop empty ones with a warning, load the pipeline and generate the
    images for each remaining prompt.

    Raises:
        SystemExit: With status 1 when no usable prompt is available.
    """
    args = _build_arg_parser().parse_args()
    cfg = load_config(args.config)

    if args.output:
        cfg["output"]["dir"] = args.output
    elif not args.flat_output:
        root = Path(cfg["output"].get("dir", "./output"))
        stamp = datetime.now()
        dated = root.joinpath(stamp.strftime("%Y-%m-%d"), stamp.strftime("%H-%M-%S"))
        cfg["output"]["dir"] = str(dated)

    prompts = _collect_prompts(args)
    if not prompts:
        print("未提供任何 prompt。", file=sys.stderr)
        sys.exit(1)

    # Keep non-empty prompts and warn about each empty one by its original
    # 1-based position.
    usable: list[str] = []
    for position, text in enumerate(prompts, 1):
        if text:
            usable.append(text)
        else:
            print(f"警告: 第 {position} 条 prompt 为空,已跳过。", file=sys.stderr)
    prompts = usable
    if not prompts:
        sys.exit(1)

    out_dir = Path(cfg["output"].get("dir", "./output"))
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = cfg["output"].get("filename_prefix", "zimg")

    img_cfg = cfg["image"]
    width, height = resolve_image_size(img_cfg)
    preset = img_cfg.get("size_preset", "custom")
    total = len(prompts)

    print(f"\n输出目录: {out_dir.resolve()}")
    size_line = f"图片尺寸: {width}×{height}"
    if preset != "custom":
        size_line += f" (预设: {preset})"
    print(size_line)
    print(f"{total} 条 prompt将依次原样送模型推理。\n")

    pipe = load_pipeline(cfg)
    all_saved: list[Path] = []
    for number, prompt in enumerate(prompts, 1):
        # Only the console preview is shortened; the model gets the full text.
        if len(prompt) <= 160:
            preview = prompt
        else:
            preview = prompt[:160] + "..."
        print(f"[{number}/{total}] Prompt:\n{preview}\n")
        all_saved += _generate(
            pipe,
            prompt,
            cfg=cfg,
            index=number,
            output_dir=out_dir,
            filename_prefix=prefix,
        )
    print(f"\n全部完成,共 {len(all_saved)} 个文件。")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()