"""
直接使用用户输入的 prompt 调用本地 Z-Image-Turbo 出图。

复用 config.yaml 中的 image / lora / output 等推理配置;
不对 prompt 做 LLM 改写,不自动拼接 LoRA 触发词(触发词请自行写进 prompt)。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import sys
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
from poetry_to_image import (
|
||
create_generator,
|
||
load_config,
|
||
load_pipeline,
|
||
resolve_image_size,
|
||
)
|
||
|
||
|
||
def _read_prompt_from_file(path: Path) -> str:
|
||
"""按 UTF-8 原样读取文件,不做 strip 或换行规范化以外的解码。"""
|
||
return path.read_bytes().decode("utf-8")
|
||
|
||
|
||
def _collect_prompts(args: argparse.Namespace) -> list[str]:
|
||
prompts: list[str] = []
|
||
if args.prompts:
|
||
prompts.extend(args.prompts)
|
||
for fp in args.prompt_files or []:
|
||
p = Path(fp)
|
||
if not p.is_file():
|
||
print(f"错误: 文件不存在: {p}", file=sys.stderr)
|
||
sys.exit(1)
|
||
prompts.append(_read_prompt_from_file(p))
|
||
if not prompts and not sys.stdin.isatty():
|
||
prompts.append(sys.stdin.buffer.read().decode("utf-8"))
|
||
if not prompts and sys.stdin.isatty():
|
||
print("请输入 prompt(空行结束):")
|
||
lines: list[str] = []
|
||
while True:
|
||
try:
|
||
line = input()
|
||
except EOFError:
|
||
break
|
||
if line == "":
|
||
break
|
||
lines.append(line)
|
||
text = "\n".join(lines)
|
||
if text:
|
||
prompts.append(text)
|
||
return prompts
|
||
|
||
|
||
def _generate(
    pipe,
    prompt: str,
    *,
    cfg: dict,
    index: int,
    output_dir: Path,
    filename_prefix: str,
) -> list[Path]:
    """Render one prompt; the prompt string is passed to the pipeline verbatim.

    Produces ``images_per_prompt`` variants (clamped to 1..10), each with its
    own seed, and optionally saves the prompt text alongside every image.
    Returns the list of saved image paths.
    """
    image_cfg = cfg["image"]
    output_cfg = cfg["output"]
    device = cfg.get("_resolved_device", "cpu")

    width, height = resolve_image_size(image_cfg)
    steps = image_cfg.get("num_inference_steps", 9)
    guidance = image_cfg.get("guidance_scale", 0.0)
    base_seed = image_cfg.get("seed", -1)
    n_variants = max(1, min(10, image_cfg.get("images_per_prompt", 1)))

    saved_paths: list[Path] = []
    for variant in range(n_variants):
        # Per-(prompt, variant) offset keeps seeds distinct across the batch.
        offset = index * 100 + variant
        if base_seed >= 0:
            # Deterministic mode: derive each variant from the configured seed.
            actual_seed = base_seed + offset
        else:
            # Random mode: millisecond clock folded into 32 bits.
            actual_seed = int(time.time() * 1000) % (2**32) + offset
        generator = create_generator(device, actual_seed)
        # Variants get an "a"/"b"/... filename suffix to avoid overwrites.
        suffix = "" if n_variants == 1 else chr(ord("a") + variant)

        if n_variants > 1:
            print(f"  --- 第 {variant + 1}/{n_variants} 张 (seed={actual_seed}) ---")

        start_ts = time.time()
        result = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
        )
        image = result.images[0]
        print(f"  生成完成,耗时 {time.time() - start_ts:.1f}s")

        stem = f"{filename_prefix}_{index:02d}{suffix}"
        img_path = output_dir / f"{stem}.png"
        image.save(img_path)
        saved_paths.append(img_path)
        print(f"  已保存: {img_path}")

        if output_cfg.get("save_prompts", True):
            # Raw UTF-8 bytes, mirroring how prompts are read in.
            (output_dir / f"{stem}_prompt.txt").write_bytes(prompt.encode("utf-8"))

    return saved_paths
|
||
|
||
|
||
def main() -> None:
    """CLI entry point: parse arguments, collect prompts, run the pipeline.

    Fix: when every collected prompt turns out empty, the script previously
    exited with status 1 *silently*; it now prints a reason to stderr like
    the other failure paths do.
    """
    parser = argparse.ArgumentParser(
        description="Z-Image-Turbo 直出图:使用用户给定 prompt,不做文本侧处理"
    )
    parser.add_argument(
        "-c", "--config",
        default="config.yaml",
        help="配置文件路径(默认: config.yaml)",
    )
    parser.add_argument(
        "-p", "--prompt",
        action="append",
        dest="prompts",
        metavar="TEXT",
        help="prompt 文本;可多次指定以连续生成多张不同 prompt",
    )
    parser.add_argument(
        "-f", "--file",
        action="append",
        dest="prompt_files",
        metavar="PATH",
        help="从 UTF-8 文件读取整段 prompt;可多次指定",
    )
    parser.add_argument(
        "-o", "--output",
        default=None,
        help="输出目录(默认: output 下按日期时间分子目录,与 poetry_to_image 一致)",
    )
    parser.add_argument(
        "--flat-output",
        action="store_true",
        help="将输出直接写入配置中的 output.dir,不再追加 日期/时间 子目录",
    )
    args = parser.parse_args()

    cfg = load_config(args.config)
    # -o overrides the output dir outright; otherwise (unless --flat-output)
    # nest a date/time subdirectory under the configured base dir.
    if args.output:
        cfg["output"]["dir"] = args.output
    elif not args.flat_output:
        base = Path(cfg["output"].get("dir", "./output"))
        now = datetime.now()
        cfg["output"]["dir"] = str(base / now.strftime("%Y-%m-%d") / now.strftime("%H-%M-%S"))

    prompts = _collect_prompts(args)
    if not prompts:
        print("未提供任何 prompt。", file=sys.stderr)
        sys.exit(1)

    # Warn about (and drop) empty prompts instead of sending them to the model.
    for k, p in enumerate(prompts, 1):
        if not p:
            print(f"警告: 第 {k} 条 prompt 为空,已跳过。", file=sys.stderr)
    prompts = [p for p in prompts if p]
    if not prompts:
        # BUG FIX: previously exited with no message at all.
        print("所有 prompt 均为空,退出。", file=sys.stderr)
        sys.exit(1)

    out_dir = Path(cfg["output"].get("dir", "./output"))
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = cfg["output"].get("filename_prefix", "zimg")

    img_cfg = cfg["image"]
    width, height = resolve_image_size(img_cfg)
    preset = img_cfg.get("size_preset", "custom")
    print(f"\n输出目录: {out_dir.resolve()}")
    print(f"图片尺寸: {width}×{height}" + (f" (预设: {preset})" if preset != "custom" else ""))
    print(f"共 {len(prompts)} 条 prompt,将依次原样送模型推理。\n")

    pipe = load_pipeline(cfg)

    all_saved: list[Path] = []
    for i, prompt in enumerate(prompts, 1):
        # Truncate very long prompts in the console preview only; the full
        # text still goes to the pipeline untouched.
        preview = prompt if len(prompt) <= 160 else prompt[:160] + "..."
        print(f"[{i}/{len(prompts)}] Prompt:\n{preview}\n")
        all_saved.extend(
            _generate(
                pipe,
                prompt,
                cfg=cfg,
                index=i,
                output_dir=out_dir,
                filename_prefix=prefix,
            )
        )

    print(f"\n全部完成,共 {len(all_saved)} 个文件。")


if __name__ == "__main__":
    main()
|