prompt.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. from typing import Any, Dict, Optional, List, Set, Tuple
  2. from rich.prompt import Prompt, IntPrompt, Confirm
  3. import typer
  4. import sys
  5. class PromptHandler:
  6. def __init__(self, declared_variables: Dict[str, Tuple[str, Dict[str, Any]]], variable_sets: Dict[str, Dict[str, Any]]):
  7. self._declared = declared_variables
  8. self.variable_sets = variable_sets
  9. @staticmethod
  10. def ask_bool(prompt_text: str, default: bool = False) -> bool:
  11. """Ask a yes/no question, render default in cyan when in a TTY, and
  12. fall back to typer.confirm when not attached to a TTY.
  13. """
  14. if not (sys.stdin.isatty() and sys.stdout.isatty()):
  15. return typer.confirm(prompt_text, default=default)
  16. if default:
  17. indicator = "[cyan]Y[/cyan]/n"
  18. else:
  19. indicator = "y/[cyan]N[/cyan]"
  20. prompt_full = f"{prompt_text} [{indicator}]"
  21. resp = Prompt.ask(prompt_full, default="", show_default=False)
  22. if resp is None or str(resp).strip() == "":
  23. return bool(default)
  24. r = str(resp).strip().lower()
  25. return r[0] in ("y", "1", "t")
  26. @staticmethod
  27. def ask_int(prompt_text: str, default: Optional[int] = None) -> int:
  28. return IntPrompt.ask(prompt_text, default=default, show_default=True)
  29. @staticmethod
  30. def ask_str(prompt_text: str, default: Optional[str] = None, show_default: bool = True) -> str:
  31. return Prompt.ask(prompt_text, default=default, show_default=show_default)
  32. def collect_values(self, used_vars: Set[str], template_defaults: Dict[str, Any] = None, used_subscripts: Dict[str, Set[str]] = None) -> Dict[str, Any]:
  33. """Interactively prompt for values for the variables that appear in the template.
  34. For variables that were declared in `variable_sets` we use their metadata.
  35. For unknown variables, we fall back to a generic prompt.
  36. """
  37. if template_defaults is None:
  38. template_defaults = {}
  39. values: Dict[str, Any] = {}
  40. # Group used vars by their set.
  41. # Iterate through declared variable_sets so the prompt order
  42. # matches the order variables were defined in each set.
  43. set_used_vars: Dict[str, List[str]] = {}
  44. for set_name, set_def in self.variable_sets.items():
  45. vars_map = set_def.get("variables") if isinstance(set_def, dict) and "variables" in set_def else set_def
  46. if not isinstance(vars_map, dict):
  47. continue
  48. for var_name in vars_map.keys():
  49. if var_name in used_vars and var_name in self._declared:
  50. if set_name not in set_used_vars:
  51. set_used_vars[set_name] = []
  52. set_used_vars[set_name].append(var_name)
  53. # If the set name is used as a variable, include the set for prompting
  54. if set_name in used_vars and set_name not in set_used_vars:
  55. set_used_vars[set_name] = []
  56. # Process each set
  57. for set_name, vars_in_set in set_used_vars.items():
  58. # Retrieve per-set definition to pick up the custom prompt if provided
  59. set_def = self.variable_sets.get(set_name, {})
  60. set_prompt = set_def.get("prompt") if isinstance(set_def, dict) else None
  61. typer.secho(f"\n{set_name.title()} Settings", fg=typer.colors.BLUE, bold=True)
  62. def _print_defaults_for_set(vars_list):
  63. # Print each variable and its default value (field name: grey, value: cyan)
  64. for v in vars_list:
  65. meta_info = self._declared[v][1]
  66. display_name = meta_info.get("display_name", v.replace("_", " ").title())
  67. default = self._get_effective_default(v, template_defaults, values)
  68. # If variable is accessed with subscripts, show '(multiple)'
  69. if used_subscripts and v in used_subscripts and used_subscripts[v]:
  70. typer.secho(f"{display_name}: ", fg=typer.colors.BRIGHT_BLACK, nl=False)
  71. typer.secho("(multiple)", fg=typer.colors.CYAN)
  72. else:
  73. typer.secho(f"{display_name}: ", fg=typer.colors.BRIGHT_BLACK, nl=False)
  74. typer.secho(f"{default}", fg=typer.colors.CYAN)
  75. # Decide whether this set is enabled and whether it should be
  76. # customized. Support three modes in the set definition:
  77. # - 'always': True => the set is enabled and we skip the enable
  78. # question (but may still ask to customize values)
  79. # - 'prompt_enable': str => ask this question first to enable the
  80. # set (stores values[set_name] boolean)
  81. # - 'prompt' (existing): when provided, ask whether to customize
  82. # the values. We ask 'prompt_enable' first when present, then
  83. # 'prompt' to decide whether to customize.
  84. set_always = bool(set_def.get('always', False))
  85. set_prompt_enable = set_def.get('prompt_enable')
  86. set_customize_prompt = set_prompt or f"Do you want to change the {set_name.title()} settings?"
  87. if set_always:
  88. enable_set = True
  89. elif set_prompt_enable:
  90. enable_set = self.ask_bool(set_prompt_enable, default=False)
  91. else:
  92. # No explicit enable prompt: fall back to asking the customize prompt
  93. # and treat that as enabling when answered Yes.
  94. enable_set = None
  95. # If we have a definitive enable decision, store it into values
  96. if enable_set is not None:
  97. values[set_name] = enable_set
  98. # If a declared variable exists with the same name, don't prompt it
  99. if set_name in vars_in_set:
  100. vars_in_set = [v for v in vars_in_set if v != set_name]
  101. # If we didn't ask prompt_enable, ask the customize prompt directly
  102. if enable_set is None:
  103. # In this mode we treat the customize prompt as the enable decision.
  104. change_set = self.ask_bool(set_customize_prompt, default=False)
  105. values[set_name] = change_set
  106. if set_name in vars_in_set:
  107. vars_in_set = [v for v in vars_in_set if v != set_name]
  108. if not change_set:
  109. # Use defaults for this set
  110. for var in vars_in_set:
  111. meta_info = self._declared[var][1]
  112. default = self._get_effective_default(var, template_defaults, values)
  113. values[var] = default
  114. continue
  115. # If we had an enable_set (True/False) and it is False, skip customizing
  116. if enable_set is not None and not enable_set:
  117. for var in vars_in_set:
  118. meta_info = self._declared[var][1]
  119. default = self._get_effective_default(var, template_defaults, values)
  120. values[var] = default
  121. continue
  122. # At this point the set is enabled. Print defaults now (only after
  123. # enabling) so the user sees current values before customizing.
  124. _print_defaults_for_set(vars_in_set)
  125. # If we have asked prompt_enable earlier (and the set is enabled),
  126. # now ask whether to customize. For 'always' sets we still ask the
  127. # customize prompt.
  128. if set_prompt_enable or set_always:
  129. change_set = self.ask_bool(set_customize_prompt, default=False)
  130. if not change_set:
  131. for var in vars_in_set:
  132. meta_info = self._declared[var][1]
  133. default = self._get_effective_default(var, template_defaults, values)
  134. values[var] = default
  135. continue
  136. # Prompt for each variable in the set
  137. for var in vars_in_set:
  138. meta_info = self._declared[var][1]
  139. display_name = meta_info.get("display_name", var.replace("_", " ").title())
  140. vtype = meta_info.get("type", "str")
  141. prompt = meta_info.get("prompt", f"Enter {display_name}")
  142. default = self._get_effective_default(var, template_defaults, values)
  143. # Build prompt text and rely on show_default to display the default value
  144. prompt_text = f"{prompt}"
  145. # If variable is accessed with subscripts in the template, always prompt for each key and store as dict
  146. subs = used_subscripts.get(var, set()) if used_subscripts else set()
  147. if subs:
  148. # Print all default values for subscripted keys before prompting
  149. for k in subs:
  150. key_default = None
  151. if isinstance(default, dict):
  152. key_default = default.get(k)
  153. elif default is not None:
  154. key_default = default
  155. typer.secho(f"{display_name}['{k}']: ", fg=typer.colors.BRIGHT_BLACK, nl=False)
  156. typer.secho(f"{key_default}", fg=typer.colors.CYAN)
  157. result_map = {}
  158. for k in subs:
  159. kval = Prompt.ask(f"Value for {display_name}['{k}']:", default=str(default.get(k)) if isinstance(default, dict) and default.get(k) is not None else None, show_default=True)
  160. result_map[k] = self._guess_and_cast(kval)
  161. values[var] = result_map
  162. continue
  163. if vtype == "bool":
  164. # Normalize default to bool
  165. bool_default = False
  166. if isinstance(default, bool):
  167. bool_default = default
  168. elif isinstance(default, str):
  169. bool_default = default.lower() in ("true", "1", "yes")
  170. elif isinstance(default, int):
  171. bool_default = default != 0
  172. val = self.ask_bool(prompt_text, default=bool_default)
  173. elif vtype == "int":
  174. # Use IntPrompt to validate and parse integers; show default if present
  175. int_default = None
  176. if isinstance(default, int):
  177. int_default = default
  178. elif isinstance(default, str) and default.isdigit():
  179. int_default = int(default)
  180. val = IntPrompt.ask(prompt_text, default=int_default, show_default=True)
  181. else:
  182. # Use Prompt for string input and show default
  183. str_default = str(default) if default is not None else None
  184. val = Prompt.ask(prompt_text, default=str_default, show_default=True)
  185. # Handle collection types: arrays and maps
  186. if vtype in ("array", "list"):
  187. values[var] = self.prompt_array(var, meta_info, default)
  188. continue
  189. if vtype in ("map", "dict"):
  190. # If the template indexes this variable with specific keys, prompt per-key
  191. subs = used_subscripts.get(var, set()) if used_subscripts else set()
  192. if subs:
  193. # Prompt for each accessed key; allow single scalar default to apply to all
  194. result_map = {}
  195. # If default is a scalar, ask whether to expand it to accessed keys
  196. if not isinstance(default, dict) and default is not None:
  197. use_single = self.ask_bool(f"Use single value {default} for all {display_name} keys?", default=True)
  198. if use_single:
  199. for k in subs:
  200. result_map[k] = default
  201. values[var] = result_map
  202. continue
  203. # Otherwise prompt per key or use metadata keys when present
  204. keys_meta = meta_info.get("keys")
  205. for k in subs:
  206. if isinstance(keys_meta, dict) and k in keys_meta:
  207. # reuse metadata prompt
  208. kmeta = keys_meta[k]
  209. result_map[k] = self.prompt_scalar(k, kmeta, kmeta.get("default"))
  210. else:
  211. # generic prompt
  212. kval = self.ask_str(f"Value for {display_name}['{k}']:")
  213. result_map[k] = self._guess_and_cast(kval)
  214. values[var] = result_map
  215. continue
  216. # Fallback to full map prompting
  217. values[var] = self.prompt_map(var, meta_info, default)
  218. continue
  219. # store scalar/canonicalized value
  220. values[var] = self._cast_value_from_input(val, vtype)
  221. # Handle unknown variables. If a variable was already set (for
  222. # example by the set-level prompt mapping into `values[set_name]`),
  223. # don't prompt for it again.
  224. for var in used_vars:
  225. if var not in self._declared and var not in values:
  226. prompt_text = f"Value for '{var}':"
  227. val = Prompt.ask(prompt_text, default="", show_default=False)
  228. values[var] = self._guess_and_cast(val)
  229. return values
  230. def _get_effective_default(self, var_name: str, template_defaults: Dict[str, Any], current_values: Dict[str, Any]):
  231. # Prefer template-provided default, else declared metadata default
  232. meta_info = self._declared.get(var_name, ({}, {}))[1] if var_name in self._declared else {}
  233. candidate = None
  234. if template_defaults and var_name in template_defaults:
  235. candidate = template_defaults[var_name]
  236. else:
  237. candidate = meta_info.get("default") if isinstance(meta_info, dict) else None
  238. # If candidate names another variable and that variable has already
  239. # been provided by the user, use that value.
  240. if isinstance(candidate, str) and candidate in current_values:
  241. return current_values[candidate]
  242. # Otherwise, try to resolve identifier references to declared defaults
  243. if isinstance(candidate, str) and candidate in self._declared:
  244. decl_def = self._declared[candidate][1].get("default")
  245. if decl_def is not None:
  246. return decl_def
  247. return candidate
  248. def prompt_scalar(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  249. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  250. vtype = meta_info.get("type", "str")
  251. prompt = meta_info.get("prompt", f"Enter {display_name}")
  252. if vtype == "bool":
  253. bool_default = False
  254. if isinstance(default_val, bool):
  255. bool_default = default_val
  256. elif isinstance(default_val, str):
  257. bool_default = default_val.lower() in ("true", "1", "yes")
  258. elif isinstance(default_val, int):
  259. bool_default = default_val != 0
  260. return self.ask_bool(prompt, default=bool_default)
  261. if vtype == "int":
  262. int_default = None
  263. if isinstance(default_val, int):
  264. int_default = default_val
  265. elif isinstance(default_val, str) and default_val.isdigit():
  266. int_default = int(default_val)
  267. return self.ask_int(prompt, default=int_default)
  268. str_default = str(default_val) if default_val is not None else None
  269. return self.ask_str(prompt, default=str_default, show_default=True)
  270. def prompt_array(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  271. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  272. item_type = meta_info.get("item_type", "str")
  273. item_prompt = meta_info.get("item_prompt", f"Enter {display_name} item")
  274. default_list = default_val if isinstance(default_val, list) else []
  275. default_count = len(default_list) if default_list else 0
  276. count = self.ask_int(f"How many entries for {display_name}?", default=default_count or 1)
  277. arr = []
  278. for i in range(count):
  279. item_default = default_list[i] if i < len(default_list) else None
  280. item_prompt_text = f"{item_prompt} [{i}]"
  281. if item_type == "int":
  282. int_d = item_default if isinstance(item_default, int) else (int(item_default) if isinstance(item_default, str) and str(item_default).isdigit() else None)
  283. item_val = self.ask_int(item_prompt_text, default=int_d)
  284. elif item_type == "bool":
  285. item_bool_d = self._cast_str_to_bool(item_default)
  286. item_val = self.ask_bool(item_prompt_text, default=item_bool_d)
  287. else:
  288. item_str_d = str(item_default) if item_default is not None else None
  289. item_val = self.ask_str(item_prompt_text, default=item_str_d, show_default=True)
  290. arr.append(self._cast_value_from_input(item_val, item_type))
  291. return arr
  292. def prompt_map(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  293. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  294. keys_meta = meta_info.get("keys")
  295. result_map = {}
  296. if isinstance(keys_meta, dict):
  297. for key_name, kmeta in keys_meta.items():
  298. kdisplay = kmeta.get("display_name", f"{display_name}['{key_name}']")
  299. ktype = kmeta.get("type", "str")
  300. kdefault = kmeta.get("default") if "default" in kmeta else (default_val.get(key_name) if isinstance(default_val, dict) and key_name in default_val else None)
  301. kprompt = kmeta.get("prompt", f"Enter value for {kdisplay}")
  302. if ktype == "int":
  303. kd = kdefault if isinstance(kdefault, int) else (int(kdefault) if isinstance(kdefault, str) and str(kdefault).isdigit() else None)
  304. kval = self.ask_int(kprompt, default=kd)
  305. elif ktype == "bool":
  306. kval = self.ask_bool(kprompt, default=self._cast_str_to_bool(kdefault))
  307. else:
  308. kval = self.ask_str(kprompt, default=str(kdefault) if kdefault is not None else None, show_default=True)
  309. result_map[key_name] = self._cast_value_from_input(kval, ktype)
  310. return result_map
  311. if isinstance(default_val, dict) and len(default_val) > 0:
  312. for key_name, kdefault in default_val.items():
  313. kprompt = f"Enter value for {display_name}['{key_name}']"
  314. kval = self.ask_str(kprompt, default=str(kdefault) if kdefault is not None else None, show_default=True)
  315. result_map[key_name] = self._guess_and_cast(kval)
  316. return result_map
  317. count = self.ask_int(f"How many named entries for {display_name}?", default=1)
  318. for i in range(count):
  319. key_name = self.ask_str(f"Key name [{i}]", default=None, show_default=False)
  320. kval = self.ask_str(f"Value for {display_name}['{key_name}']:", default=None, show_default=False)
  321. result_map[key_name] = self._guess_and_cast(kval)
  322. return result_map
  323. @staticmethod
  324. def _cast_str_to_bool(s):
  325. if isinstance(s, bool):
  326. return s
  327. if isinstance(s, int):
  328. return s != 0
  329. if isinstance(s, str):
  330. return s.lower() in ("true", "1", "yes")
  331. return False
  332. @staticmethod
  333. def _cast_value_from_input(raw, vtype):
  334. if vtype == "int":
  335. try:
  336. return int(raw)
  337. except Exception:
  338. return raw
  339. if vtype == "bool":
  340. return PromptHandler._cast_str_to_bool(raw)
  341. return raw
  342. @staticmethod
  343. def _guess_and_cast(raw):
  344. s = raw if not isinstance(raw, str) else raw.strip()
  345. if s == "":
  346. return raw
  347. if isinstance(s, str) and s.isdigit():
  348. return PromptHandler._cast_value_from_input(s, "int")
  349. if isinstance(s, str) and s.lower() in ("true", "false", "yes", "no", "1", "0", "t", "f"):
  350. return PromptHandler._cast_value_from_input(s, "bool")
  351. return PromptHandler._cast_value_from_input(s, "str")
  352. def prompt_scalar(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  353. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  354. vtype = meta_info.get("type", "str")
  355. prompt = meta_info.get("prompt", f"Enter {display_name}")
  356. if vtype == "bool":
  357. bool_default = False
  358. if isinstance(default_val, bool):
  359. bool_default = default_val
  360. elif isinstance(default_val, str):
  361. bool_default = default_val.lower() in ("true", "1", "yes")
  362. elif isinstance(default_val, int):
  363. bool_default = default_val != 0
  364. return self.ask_bool(prompt, default=bool_default)
  365. if vtype == "int":
  366. int_default = None
  367. if isinstance(default_val, int):
  368. int_default = default_val
  369. elif isinstance(default_val, str) and default_val.isdigit():
  370. int_default = int(default_val)
  371. return self.ask_int(prompt, default=int_default)
  372. str_default = str(default_val) if default_val is not None else None
  373. return self.ask_str(prompt, default=str_default, show_default=True)
  374. def prompt_array(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  375. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  376. item_type = meta_info.get("item_type", "str")
  377. item_prompt = meta_info.get("item_prompt", f"Enter {display_name} item")
  378. default_list = default_val if isinstance(default_val, list) else []
  379. default_count = len(default_list) if default_list else 0
  380. count = self.ask_int(f"How many entries for {display_name}?", default=default_count or 1)
  381. arr = []
  382. for i in range(count):
  383. item_default = default_list[i] if i < len(default_list) else None
  384. item_prompt_text = f"{item_prompt} [{i}]"
  385. if item_type == "int":
  386. int_d = item_default if isinstance(item_default, int) else (int(item_default) if isinstance(item_default, str) and str(item_default).isdigit() else None)
  387. item_val = self.ask_int(item_prompt_text, default=int_d)
  388. elif item_type == "bool":
  389. item_bool_d = self._cast_str_to_bool(item_default)
  390. item_val = self.ask_bool(item_prompt_text, default=item_bool_d)
  391. else:
  392. item_str_d = str(item_default) if item_default is not None else None
  393. item_val = self.ask_str(item_prompt_text, default=item_str_d, show_default=True)
  394. arr.append(self._cast_value_from_input(item_val, item_type))
  395. return arr
  396. def prompt_map(self, var_name: str, meta_info: Dict[str, Any], default_val: Any) -> Any:
  397. display_name = meta_info.get("display_name", var_name.replace("_", " ").title())
  398. keys_meta = meta_info.get("keys")
  399. result_map = {}
  400. if isinstance(keys_meta, dict):
  401. for key_name, kmeta in keys_meta.items():
  402. kdisplay = kmeta.get("display_name", f"{display_name}['{key_name}']")
  403. ktype = kmeta.get("type", "str")
  404. kdefault = kmeta.get("default") if "default" in kmeta else (default_val.get(key_name) if isinstance(default_val, dict) and key_name in default_val else None)
  405. kprompt = kmeta.get("prompt", f"Enter value for {kdisplay}")
  406. if ktype == "int":
  407. kd = kdefault if isinstance(kdefault, int) else (int(kdefault) if isinstance(kdefault, str) and str(kdefault).isdigit() else None)
  408. kval = self.ask_int(kprompt, default=kd)
  409. elif ktype == "bool":
  410. kval = self.ask_bool(kprompt, default=self._cast_str_to_bool(kdefault))
  411. else:
  412. kval = self.ask_str(kprompt, default=str(kdefault) if kdefault is not None else None, show_default=True)
  413. result_map[key_name] = self._cast_value_from_input(kval, ktype)
  414. return result_map
  415. if isinstance(default_val, dict) and len(default_val) > 0:
  416. for key_name, kdefault in default_val.items():
  417. kprompt = f"Enter value for {display_name}['{key_name}']"
  418. kval = self.ask_str(kprompt, default=str(kdefault) if kdefault is not None else None, show_default=True)
  419. result_map[key_name] = self._guess_and_cast(kval)
  420. return result_map
  421. count = self.ask_int(f"How many named entries for {display_name}?", default=1)
  422. for i in range(count):
  423. key_name = self.ask_str(f"Key name [{i}]", default=None, show_default=False)
  424. kval = self.ask_str(f"Value for {display_name}['{key_name}']:", default=None, show_default=False)
  425. result_map[key_name] = self._guess_and_cast(kval)
  426. return result_map
  427. @staticmethod
  428. def _cast_str_to_bool(s):
  429. if isinstance(s, bool):
  430. return s
  431. if isinstance(s, int):
  432. return s != 0
  433. if isinstance(s, str):
  434. return s.lower() in ("true", "1", "yes")
  435. return False
  436. @staticmethod
  437. def _cast_value_from_input(raw, vtype):
  438. if vtype == "int":
  439. try:
  440. return int(raw)
  441. except Exception:
  442. return raw
  443. if vtype == "bool":
  444. return PromptHandler._cast_str_to_bool(raw)
  445. return raw
  446. @staticmethod
  447. def _guess_and_cast(raw):
  448. s = raw if not isinstance(raw, str) else raw.strip()
  449. if s == "":
  450. return raw
  451. if isinstance(s, str) and s.isdigit():
  452. return PromptHandler._cast_value_from_input(s, "int")
  453. if isinstance(s, str) and s.lower() in ("true", "false", "yes", "no", "1", "0", "t", "f"):
  454. return PromptHandler._cast_value_from_input(s, "bool")
  455. return PromptHandler._cast_value_from_input(s, "str")