| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- from __future__ import annotations
- from typing import Dict, Any, List, Callable
- import logging
- from rich.console import Console
- from rich.prompt import Prompt, Confirm, IntPrompt
- from rich.table import Table
- from .display import DisplayManager
- from .variables import Variable, VariableCollection
- logger = logging.getLogger(__name__)
- # ---------------------------
- # SECTION: PromptHandler Class
- # ---------------------------
class PromptHandler:
    """Simple interactive prompt handler for collecting template variables.

    Walks the sections of a ``VariableCollection``, asks the user for each
    variable through Rich prompts, writes accepted values back onto the
    variables, and reports which values differ from what was there before.
    """

    # Number of consecutive unexpected (non-validation) failures tolerated
    # for a single variable before the error is propagated instead of retried.
    _MAX_UNEXPECTED_RETRIES = 3

    def __init__(self) -> None:
        self.console = Console()
        self.display = DisplayManager()

    # --------------------------
    # SECTION: Public Methods
    # --------------------------

    def collect_variables(self, variables: VariableCollection) -> dict[str, Any]:
        """Collect values for variables by iterating through sections.

        Sections without variables are ignored, and sections whose
        dependencies are not satisfied are reported and skipped. Every
        accepted value is also written back to ``variable.value``.

        Args:
            variables: VariableCollection with organized sections and variables

        Returns:
            Dict of variable names to collected values. Only values that
            differ from the variable's current typed value are included;
            the dict is empty when the user declines to customize anything.
        """
        if not Confirm.ask("Customize any settings?", default=False):
            logger.info("User opted to keep all default values")
            return {}

        collected: dict[str, Any] = {}

        for section_key, section in variables.get_sections().items():
            if not section.variables:
                continue

            # Tell the user why a section is skipped instead of silently
            # omitting it.
            if not variables.is_section_satisfied(section_key):
                unsatisfied = [
                    dep
                    for dep in section.needs
                    if not variables.is_section_satisfied(dep)
                ]
                dep_names = ", ".join(unsatisfied) if unsatisfied else "unknown"
                self.console.print(
                    f"\n[dim]⊘ {section.title} (skipped - requires {dep_names} to be enabled)[/dim]"
                )
                logger.debug(
                    "Skipping section '%s' - dependencies not satisfied: %s",
                    section_key,
                    dep_names,
                )
                continue

            # Always show the section header first.
            self.display.display_section_header(section.title, section.description)

            if section.required:
                # Required sections are always processed, no toggle prompt.
                logger.debug(
                    "Processing required section '%s' without toggle prompt",
                    section.key,
                )
            elif section.toggle:
                toggle_var = section.variables.get(section.toggle)
                if toggle_var:
                    # Prefer the description as prompt text, else the title.
                    prompt_text = (
                        section.description
                        if section.description
                        else f"Enable {section.title}?"
                    )
                    current_value = toggle_var.get_typed_value()
                    new_value = self._prompt_bool(prompt_text, current_value)

                    if new_value != current_value:
                        collected[toggle_var.name] = new_value
                        toggle_var.value = new_value

                    # The section decides for itself whether it is enabled
                    # after the toggle has been updated.
                    if not section.is_enabled():
                        continue

            for var_name, variable in section.variables.items():
                # The toggle variable was already handled above.
                if section.toggle and var_name == section.toggle:
                    continue

                current_value = variable.get_typed_value()
                # section.required lets _prompt_variable reject empty input.
                new_value = self._prompt_variable(variable, required=section.required)

                if new_value != current_value:
                    collected[var_name] = new_value
                    variable.value = new_value

        logger.info(
            "Variable collection completed. Collected %d values", len(collected)
        )
        return collected

    # !SECTION

    # ---------------------------
    # SECTION: Private Methods
    # ---------------------------

    @staticmethod
    def _is_blank(value: Any) -> bool:
        """Return True when ``value`` is ``None`` or the empty string."""
        return value is None or (isinstance(value, str) and value == "")

    @staticmethod
    def _coerce_bool_default(default: Any) -> bool | None:
        """Turn a stored default into a bool, keeping ``None`` as "no default".

        Real booleans pass through unchanged; anything else is stringified
        and treated as true only for "true", "1", "yes" or "on" (any case).
        """
        if default is None:
            return None
        if isinstance(default, bool):
            return default
        return str(default).lower() in ("true", "1", "yes", "on")

    def _prompt_variable(self, variable: Variable, required: bool = False) -> Any:
        """Prompt for a single variable value based on its type.

        The prompt repeats until ``variable.convert`` accepts the input.

        Args:
            variable: Variable to prompt for.
            required: When True, empty input is rejected (except for
                autogenerated variables).

        Returns:
            The converted value, or ``None`` for an autogenerated variable
            left empty (signalling that auto-generation should happen).
            The caller is responsible for updating ``variable.value``.

        Raises:
            EOFError: If input ends while prompting; retrying cannot succeed.
            Exception: Any other unexpected error, once it has occurred
                ``_MAX_UNEXPECTED_RETRIES`` times in a row.
        """
        logger.debug(
            "Prompting for variable '%s' (type: %s)", variable.name, variable.type
        )

        prompt_text = variable.get_prompt_text()
        default_value = variable.get_normalized_default()

        # Mark required variables that have no default; autogenerated
        # variables are exempt because they may legitimately stay empty.
        if required and default_value is None and not variable.autogenerated:
            prompt_text = f"{prompt_text} [bold red]*required[/bold red]"

        handler = self._get_prompt_handler(variable)

        # The validation hint is appended after the required marker.
        hint = variable.get_validation_hint()
        if hint:
            prompt_text = f"{prompt_text} [dim]{hint}[/dim]"

        unexpected_failures = 0
        while True:
            try:
                raw = handler(prompt_text, default_value)
                converted = variable.convert(raw)

                if self._is_blank(converted):
                    if variable.autogenerated:
                        # None tells the caller to auto-generate the value.
                        return None
                    if required:
                        raise ValueError("value cannot be empty for required variable")

                return converted
            except ValueError as exc:
                # Conversion/validation failed: show the error and re-prompt.
                unexpected_failures = 0
                self._show_validation_error(str(exc))
            except EOFError:
                # Input is exhausted; retrying would loop forever.
                raise
            except Exception as exc:
                unexpected_failures += 1
                logger.error(
                    "Error prompting for variable '%s': %s", variable.name, exc
                )
                if unexpected_failures >= self._MAX_UNEXPECTED_RETRIES:
                    # A persistent failure is not user-recoverable; surface it
                    # rather than retrying without bound.
                    raise
                # Retry using the stored (unconverted) value as the default.
                default_value = variable.value
                handler = self._get_prompt_handler(variable)

    def _get_prompt_handler(self, variable: Variable) -> Callable[[str, Any], Any]:
        """Return the prompt function for a variable type.

        The returned callable takes ``(prompt_text, default)``. Types other
        than "bool", "int" and "enum" fall back to a string prompt that
        honours ``variable.sensitive``.
        """
        handlers = {
            "bool": self._prompt_bool,
            "int": self._prompt_int,
            # Enum prompts receive the variable's options and its optional
            # ``extra`` attribute.
            "enum": lambda text, default: self._prompt_enum(
                text,
                variable.options or [],
                default,
                extra=getattr(variable, "extra", None),
            ),
        }
        return handlers.get(
            variable.type,
            lambda text, default: self._prompt_string(
                text, default, is_sensitive=variable.sensitive
            ),
        )

    def _show_validation_error(self, message: str) -> None:
        """Display validation feedback consistently."""
        self.display.display_validation_error(message)

    def _prompt_string(
        self, prompt_text: str, default: Any = None, is_sensitive: bool = False
    ) -> str | None:
        """Prompt for free text.

        Args:
            prompt_text: Text shown to the user.
            default: Pre-filled default; stringified unless ``None``.
            is_sensitive: Requests password-style input.

        Returns:
            The stripped input, or ``None`` when the result is empty.
        """
        value = Prompt.ask(
            prompt_text,
            default=str(default) if default is not None else "",
            # Never print the default of a sensitive value next to the prompt.
            show_default=not is_sensitive,
            password=is_sensitive,
        )
        if value is None:
            return None
        return value.strip() or None

    def _prompt_bool(self, prompt_text: str, default: Any = None) -> bool | None:
        """Prompt for a yes/no answer.

        ``default`` is coerced with ``_coerce_bool_default`` and handed to
        ``Confirm.ask``; with no default the prompt's own result is returned
        unchanged.
        """
        return Confirm.ask(prompt_text, default=self._coerce_bool_default(default))

    def _prompt_int(self, prompt_text: str, default: Any = None) -> int | None:
        """Prompt for an integer.

        A default that cannot be converted with ``int()`` is logged and
        dropped, so the prompt is then shown without a default.
        """
        default_int = None
        if default is not None:
            try:
                default_int = int(default)
            except (ValueError, TypeError):
                logger.warning("Invalid default integer value: %s", default)
        return IntPrompt.ask(prompt_text, default=default_int)

    def _prompt_enum(
        self,
        prompt_text: str,
        options: list[str],
        default: Any = None,
        extra: str | None = None,
    ) -> str | None:
        """Prompt for enum selection with validation.

        Note: prompt_text should already include hint from
        variable.get_validation_hint().

        Args:
            prompt_text: Text shown to the user.
            options: Allowed choices; compared by their string form.
            default: Preferred choice; the first option is used when it is
                empty or not one of the options.
            extra: Accepted for backward compatibility; not used here.

        Returns:
            The chosen option as a string. With no options this falls back
            to a plain string prompt, which may return ``None``.
        """
        if not options:
            return self._prompt_string(prompt_text, default)

        # Prompt.ask returns text, so compare against stringified options;
        # otherwise non-string options could never be selected.
        choices = [str(option) for option in options]

        selected = str(default) if default else choices[0]
        if selected not in choices:
            selected = choices[0]

        while True:
            value = Prompt.ask(prompt_text, default=selected, show_default=True)
            if value in choices:
                return value
            self.console.print(
                f"[red]Invalid choice. Select from: {', '.join(choices)}[/red]"
            )

    # !SECTION
|