scripts: mypy fixes

This commit is contained in:
Daniel Eklöf 2025-12-26 13:13:01 +01:00
parent cb1e152d99
commit bbebe0f330
No known key found for this signature in database
GPG key ID: 5BBD4992C116573F
5 changed files with 77 additions and 51 deletions

View file

@ -3,13 +3,10 @@
import argparse
import os
import re
import sys
from typing import Dict, Union
class Capability:
def __init__(self, name: str, value: Union[bool, int, str]):
def __init__(self, name: str, value: bool | int | str) -> None:
self._name = name
self._value = value
@ -18,30 +15,42 @@ class Capability:
return self._name
@property
def value(self) -> Union[bool, int, str]:
def value(self) -> bool | int | str:
return self._value
def __lt__(self, other):
def __lt__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return self._name < other._name
def __le__(self, other):
def __le__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return self._name <= other._name
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return self._name == other._name
def __ne__(self, other):
return self._name != other._name
def __ne__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return bool(self._name != other._name)
def __gt__(self, other):
return self._name > other._name
def __gt__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return bool(self._name > other._name)
def __ge__(self, other):
def __ge__(self, other: object) -> bool:
if not isinstance(other, Capability):
return NotImplemented
return self._name >= other._name
class BoolCapability(Capability):
def __init__(self, name: str):
def __init__(self, name: str) -> None:
super().__init__(name, True)
@ -50,11 +59,11 @@ class IntCapability(Capability):
class StringCapability(Capability):
def __init__(self, name: str, value: str):
def __init__(self, name: str, value: str) -> None:
# see terminfo(5) for valid escape sequences
# Control characters
def translate_ctrl_chr(m):
def translate_ctrl_chr(m: re.Match[str]) -> str:
ctrl = m.group(1)
if ctrl == '?':
return '\\x7f'
@ -83,10 +92,10 @@ class StringCapability(Capability):
class Fragment:
def __init__(self, name: str, description: str):
def __init__(self, name: str, description: str) -> None:
self._name = name
self._description = description
self._caps = {}
self._caps = dict[str, Capability]()
@property
def name(self) -> str:
@ -97,18 +106,18 @@ class Fragment:
return self._description
@property
def caps(self) -> Dict[str, Capability]:
def caps(self) -> dict[str, Capability]:
return self._caps
def add_capability(self, cap: Capability):
def add_capability(self, cap: Capability) -> None:
assert cap.name not in self._caps
self._caps[cap.name] = cap
def del_capability(self, name: str):
def del_capability(self, name: str) -> None:
del self._caps[name]
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument('source_entry_name')
parser.add_argument('source', type=argparse.FileType('r'))
@ -121,15 +130,15 @@ def main():
source = opts.source
target = opts.target
lines = []
for l in source.readlines():
l = l.strip()
if l.startswith('#'):
lines = list[str]()
for line in source.readlines():
line = line.strip()
if line.startswith('#'):
continue
lines.append(l)
lines.append(line)
fragments = {}
cur_fragment = None
fragments = dict[str, Fragment]()
cur_fragment: Fragment | None = None
for m in re.finditer(
r'(?P<name>(?P<entry_name>[-+\w@]+)\|(?P<entry_desc>.+?),)|'
@ -148,17 +157,20 @@ def main():
elif m.group('bool_cap') is not None:
name = m.group('bool_name')
assert cur_fragment is not None
cur_fragment.add_capability(BoolCapability(name))
elif m.group('int_cap') is not None:
name = m.group('int_name')
value = int(m.group('int_val'), 0)
cur_fragment.add_capability(IntCapability(name, value))
int_value = int(m.group('int_val'), 0)
assert cur_fragment is not None
cur_fragment.add_capability(IntCapability(name, int_value))
elif m.group('str_cap') is not None:
name = m.group('str_name')
value = m.group('str_val')
cur_fragment.add_capability(StringCapability(name, value))
str_value = m.group('str_val')
assert cur_fragment is not None
cur_fragment.add_capability(StringCapability(name, str_value))
else:
assert False
@ -167,6 +179,9 @@ def main():
for frag in fragments.values():
for cap in frag.caps.values():
if cap.name == 'use':
assert isinstance(cap, StringCapability)
assert isinstance(cap.value, str)
use_frag = fragments[cap.value]
for use_cap in use_frag.caps.values():
frag.add_capability(use_cap)
@ -188,7 +203,7 @@ def main():
entry.add_capability(IntCapability('RGB', 8)) # 8 bits per channel
entry.add_capability(StringCapability('query-os-name', os.uname().sysname))
terminfo_parts = []
terminfo_parts = list[str]()
for cap in sorted(entry.caps.values()):
name = cap.name
value = str(cap.value)
@ -212,4 +227,4 @@ def main():
if __name__ == '__main__':
sys.exit(main())
main()