优化脚本,添加另一种风格的prompt
This commit is contained in:
204
prompt_to_image.py
Normal file
204
prompt_to_image.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
直接使用用户输入的 prompt 调用本地 Z-Image-Turbo 出图。
|
||||
|
||||
复用 config.yaml 中的 image / lora / output 等推理配置;
|
||||
不对 prompt 做 LLM 改写,不自动拼接 LoRA 触发词(触发词请自行写进 prompt)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from poetry_to_image import (
|
||||
create_generator,
|
||||
load_config,
|
||||
load_pipeline,
|
||||
resolve_image_size,
|
||||
)
|
||||
|
||||
|
||||
def _read_prompt_from_file(path: Path) -> str:
|
||||
"""按 UTF-8 原样读取文件,不做 strip 或换行规范化以外的解码。"""
|
||||
return path.read_bytes().decode("utf-8")
|
||||
|
||||
|
||||
def _collect_prompts(args: argparse.Namespace) -> list[str]:
|
||||
prompts: list[str] = []
|
||||
if args.prompts:
|
||||
prompts.extend(args.prompts)
|
||||
for fp in args.prompt_files or []:
|
||||
p = Path(fp)
|
||||
if not p.is_file():
|
||||
print(f"错误: 文件不存在: {p}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
prompts.append(_read_prompt_from_file(p))
|
||||
if not prompts and not sys.stdin.isatty():
|
||||
prompts.append(sys.stdin.buffer.read().decode("utf-8"))
|
||||
if not prompts and sys.stdin.isatty():
|
||||
print("请输入 prompt(空行结束):")
|
||||
lines: list[str] = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
except EOFError:
|
||||
break
|
||||
if line == "":
|
||||
break
|
||||
lines.append(line)
|
||||
text = "\n".join(lines)
|
||||
if text:
|
||||
prompts.append(text)
|
||||
return prompts
|
||||
|
||||
|
||||
def _generate(
|
||||
pipe,
|
||||
prompt: str,
|
||||
*,
|
||||
cfg: dict,
|
||||
index: int,
|
||||
output_dir: Path,
|
||||
filename_prefix: str,
|
||||
) -> list[Path]:
|
||||
"""对单条 prompt 出图;prompt 字符串原样传入 pipeline。"""
|
||||
img_cfg = cfg["image"]
|
||||
out_cfg = cfg["output"]
|
||||
device = cfg.get("_resolved_device", "cpu")
|
||||
|
||||
width, height = resolve_image_size(img_cfg)
|
||||
steps = img_cfg.get("num_inference_steps", 9)
|
||||
guidance = img_cfg.get("guidance_scale", 0.0)
|
||||
seed = img_cfg.get("seed", -1)
|
||||
images_per_prompt = max(1, min(10, img_cfg.get("images_per_prompt", 1)))
|
||||
|
||||
saved: list[Path] = []
|
||||
for j in range(images_per_prompt):
|
||||
variant_offset = index * 100 + j
|
||||
if seed >= 0:
|
||||
actual_seed = seed + variant_offset
|
||||
else:
|
||||
actual_seed = int(time.time() * 1000) % (2**32) + variant_offset
|
||||
generator = create_generator(device, actual_seed)
|
||||
seed_suffix = chr(ord("a") + j) if images_per_prompt > 1 else ""
|
||||
|
||||
if images_per_prompt > 1:
|
||||
print(f" --- 第 {j + 1}/{images_per_prompt} 张 (seed={actual_seed}) ---")
|
||||
|
||||
start = time.time()
|
||||
result = pipe(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance,
|
||||
generator=generator,
|
||||
)
|
||||
image = result.images[0]
|
||||
elapsed = time.time() - start
|
||||
print(f" 生成完成,耗时 {elapsed:.1f}s")
|
||||
|
||||
img_path = output_dir / f"{filename_prefix}_{index:02d}{seed_suffix}.png"
|
||||
image.save(img_path)
|
||||
saved.append(img_path)
|
||||
print(f" 已保存: {img_path}")
|
||||
|
||||
if out_cfg.get("save_prompts", True):
|
||||
txt_path = output_dir / f"{filename_prefix}_{index:02d}{seed_suffix}_prompt.txt"
|
||||
txt_path.write_bytes(prompt.encode("utf-8"))
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Z-Image-Turbo 直出图:使用用户给定 prompt,不做文本侧处理"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
default="config.yaml",
|
||||
help="配置文件路径(默认: config.yaml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--prompt",
|
||||
action="append",
|
||||
dest="prompts",
|
||||
metavar="TEXT",
|
||||
help="prompt 文本;可多次指定以连续生成多张不同 prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f", "--file",
|
||||
action="append",
|
||||
dest="prompt_files",
|
||||
metavar="PATH",
|
||||
help="从 UTF-8 文件读取整段 prompt;可多次指定",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output",
|
||||
default=None,
|
||||
help="输出目录(默认: output 下按日期时间分子目录,与 poetry_to_image 一致)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flat-output",
|
||||
action="store_true",
|
||||
help="将输出直接写入配置中的 output.dir,不再追加 日期/时间 子目录",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = load_config(args.config)
|
||||
if args.output:
|
||||
cfg["output"]["dir"] = args.output
|
||||
elif not args.flat_output:
|
||||
base = Path(cfg["output"].get("dir", "./output"))
|
||||
now = datetime.now()
|
||||
cfg["output"]["dir"] = str(base / now.strftime("%Y-%m-%d") / now.strftime("%H-%M-%S"))
|
||||
|
||||
prompts = _collect_prompts(args)
|
||||
if not prompts:
|
||||
print("未提供任何 prompt。", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
for k, p in enumerate(prompts, 1):
|
||||
if not p:
|
||||
print(f"警告: 第 {k} 条 prompt 为空,已跳过。", file=sys.stderr)
|
||||
|
||||
prompts = [p for p in prompts if p]
|
||||
if not prompts:
|
||||
sys.exit(1)
|
||||
|
||||
out_dir = Path(cfg["output"].get("dir", "./output"))
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
prefix = cfg["output"].get("filename_prefix", "zimg")
|
||||
|
||||
img_cfg = cfg["image"]
|
||||
width, height = resolve_image_size(img_cfg)
|
||||
preset = img_cfg.get("size_preset", "custom")
|
||||
print(f"\n输出目录: {out_dir.resolve()}")
|
||||
print(f"图片尺寸: {width}×{height}" + (f" (预设: {preset})" if preset != "custom" else ""))
|
||||
print(f"共 {len(prompts)} 条 prompt,将依次原样送模型推理。\n")
|
||||
|
||||
pipe = load_pipeline(cfg)
|
||||
|
||||
all_saved: list[Path] = []
|
||||
for i, prompt in enumerate(prompts, 1):
|
||||
preview = prompt if len(prompt) <= 160 else prompt[:160] + "..."
|
||||
print(f"[{i}/{len(prompts)}] Prompt:\n{preview}\n")
|
||||
all_saved.extend(
|
||||
_generate(
|
||||
pipe,
|
||||
prompt,
|
||||
cfg=cfg,
|
||||
index=i,
|
||||
output_dir=out_dir,
|
||||
filename_prefix=prefix,
|
||||
)
|
||||
)
|
||||
|
||||
print(f"\n全部完成,共 {len(all_saved)} 个文件。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user