# SPDX-License-Identifier: LGPL-3.0-or-later
# cython: language_level=3
# cython: linetrace=True
"""Provide utils for ReacNetGenerator."""
import asyncio
import hashlib
import itertools
import os
import pickle
import shutil
from contextlib import ExitStack
from multiprocessing import Pool, Semaphore
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
BinaryIO,
Callable,
Generator,
Iterable,
List,
Optional,
Tuple,
Union,
)
import lz4.frame
import numpy as np
import requests
from requests.adapters import HTTPAdapter
from tqdm.auto import tqdm
from ._logging import logger
if TYPE_CHECKING:
import multiprocessing.pool
import multiprocessing.synchronize
import reacnetgenerator
class WriteBuffer:
"""Store a buffer for writing files.
It is expensive to write to a file, so we need to make a buffer.
Parameters
----------
f: fileObject
The file object to write.
linenumber: int, default: 1200
The number of contents to store in the buffer. The buffer will be flushed
if it exceeds the set number.
sep: str or bytes, default: None
The separator for contents. If None (default), there will be no separator.
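
    Examples
    --------
    A minimal usage sketch; the file name is illustrative. On exit, the
    buffer is flushed and the underlying file is closed:

    >>> with WriteBuffer(open("output.txt", "w")) as wb:  # doctest: +SKIP
    ...     wb.append("a line")
    ...     wb.extend(["more", "lines"])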
"""
def __init__(
self, f: IO, linenumber: int = 1200, sep: Optional[AnyStr] = None
) -> None:
self.f = f
if sep is not None:
self.sep = sep
elif f.mode == "w":
self.sep = ""
elif f.mode == "wb":
self.sep = b""
else:
raise RuntimeError("File mode should be w or wb!")
self.linenumber = linenumber
self.buff = []
self.name = self.f.name
def append(self, text: AnyStr) -> None:
"""Append a text.
Parameters
----------
text : str or bytes
The text to be appended.
"""
self.buff.append(text)
self.check()
def extend(self, text: Iterable[AnyStr]) -> None:
"""Extend texts.
Parameters
----------
text : list of strs or bytes
Texts to be extended.
"""
self.buff.extend(text)
self.check()
def check(self) -> None:
"""Check if the number of stored contents exceeds.
If so, the buffer will be flushed.
"""
if len(self.buff) > self.linenumber:
self.flush()
def flush(self) -> None:
"""Flush the buffer."""
if self.buff:
self.f.writelines([self.sep.join(self.buff), self.sep])
self.buff[:] = []
def __enter__(self) -> "WriteBuffer":
"""Enter the context."""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context."""
self.flush()
self.f.__exit__(exc_type, exc_value, traceback)
def appendIfNotNone(f: Union[WriteBuffer, ExitStack], wbytes: Optional[AnyStr]) -> None:
"""Append a line to a file if the line is not None.
Parameters
----------
f : WriteBuffer
The file to write.
wbytes : str or bytes
The line to write.
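
    Examples
    --------
    An illustrative sketch, where wb is an open WriteBuffer; the first call
    writes nothing because the line is None:

    >>> appendIfNotNone(wb, None)  # doctest: +SKIP
    >>> appendIfNotNone(wb, "a line")  # doctest: +SKIP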
"""
if wbytes is not None:
assert not isinstance(f, ExitStack)
f.append(wbytes)
def produce(
semaphore: "multiprocessing.synchronize.Semaphore",
plist: Iterable[Any],
parameter: Any,
) -> Generator[Tuple[Any, Any], None, None]:
"""Item producer with a semaphore.
Prevent large memory usage due to slow IO.
Parameters
----------
semaphore : multiprocessing.Semaphore
The semaphore to acquire.
plist : list of objects
The list of items to be passed.
parameter : object
The parameter yielded with each item.
Yields
------
item: object
The item to be yielded.
parameter: object
The parameter yielded with each item.
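
    Examples
    --------
    With an initial semaphore value no smaller than the number of items,
    the generator can be consumed directly:

    >>> from multiprocessing import Semaphore
    >>> sem = Semaphore(2)
    >>> list(produce(sem, ["C", "H"], "mol"))
    [('C', 'mol'), ('H', 'mol')]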
"""
for item in plist:
semaphore.acquire()
if parameter is not None:
item = (item, parameter)
yield item
def compress(x: Union[str, bytes]) -> bytes:
"""Compress the line.
This function reduces IO overhead to speed up the program. The functions will
use lz4 to compress, since lz4 has better performance
that any others.
The compressed format is size + data + size + data + ..., where size is a 64-bit
little-endian integer.
Parameters
----------
x : str or bytes
The line to compress.
Returns
-------
bytes
The compressed line, with a linebreak in the end.
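
    Examples
    --------
    A line can be recovered by a round trip through decompress:

    >>> decompress(compress("C1H4"))
    'C1H4'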
"""
if isinstance(x, str):
x = x.encode()
compress_block = lz4.frame.compress(x, compression_level=0)
length_bytes = len(compress_block).to_bytes(64, byteorder="little")
return length_bytes + compress_block
def decompress(x: bytes, isbytes: bool = False) -> Union[str, bytes]:
"""Decompress the line.
Parameters
----------
x : bytes
The line to decompress.
isbytes : bool, optional, default: False
If the decompressed content is bytes. If not, the line will be decoded.
Returns
-------
str or bytes
The decompressed line.
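
    Examples
    --------
    Round trips for text and binary content:

    >>> decompress(compress("C1H4"))
    'C1H4'
    >>> decompress(compress(b"C1H4"), isbytes=True)
    b'C1H4'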
"""
x = lz4.frame.decompress(x[64:])
if isbytes:
return x
return x.decode()
def listtobytes(x: Any) -> bytes:
"""Convert an object to a compressed line.
Parameters
----------
x : object
The object to convert, such as numpy.ndarray.
Returns
-------
bytes
The compressed line.
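
    Examples
    --------
    The result can be recovered with bytestolist:

    >>> bytestolist(listtobytes(np.arange(3)))
    array([0, 1, 2])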
"""
return compress(pickle.dumps(x))
def read_compressed_block(f: BinaryIO) -> Generator[bytes, None, None]:
"""Read compressed binary file, assuming the format is size + data + size + data + ...
Parameters
----------
f : fileObject
The file object to read.
Yields
------
data: bytes
The compressed block.
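
    Examples
    --------
    Blocks written by compress can be read back one by one; io.BytesIO
    stands in for a real file here:

    >>> import io
    >>> f = io.BytesIO(compress("C") + compress("H"))
    >>> [decompress(block) for block in read_compressed_block(f)]
    ['C', 'H']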
"""
while True:
sizeb = f.read(64)
if not sizeb:
break
size = int.from_bytes(sizeb, byteorder="little")
yield sizeb + f.read(size)
def bytestolist(x: bytes) -> Any:
"""Convert a compressed line to an object.
Parameters
----------
x : bytes
The compressed line.
Returns
-------
object
The decompressed object.
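
    Examples
    --------
    A round trip through listtobytes:

    >>> bytestolist(listtobytes([1, 2, 3]))
    [1, 2, 3]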
"""
data = decompress(x, isbytes=True)
assert isinstance(data, bytes)
return pickle.loads(data)
def listtostirng(
l: Union[str, list, tuple, np.ndarray], sep: Union[List[str], Tuple[str, ...]]
) -> str:
"""Convert a list to string, that is easier to store.
Parameters
----------
l : str or array-like
The list to convert, which can contain any number of dimensions.
sep : list of strs
The seperators for each dimension.
Returns
-------
str
The converted string.
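
    Examples
    --------
    Each dimension is joined with its own separator:

    >>> listtostirng([["C", "H"], ["O"]], sep=[";", ","])
    'C,H;O'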
"""
if isinstance(l, str):
return l
if isinstance(l, (list, tuple, np.ndarray)):
return sep[0].join(listtostirng(x, sep[1:]) for x in l)
return str(l)
def multiopen(
pool: "multiprocessing.pool.Pool",
func: Callable,
l: IO,
semaphore: Optional["multiprocessing.synchronize.Semaphore"] = None,
nlines: Optional[int] = None,
unordered: bool = True,
return_num: bool = False,
start: int = 0,
extra: Optional[Any] = None,
interval: Optional[int] = None,
bar: bool = True,
desc: Optional[str] = None,
unit: str = "it",
total: Optional[int] = None,
) -> Iterable:
"""Return an interated object for process a file with multiple processors.
Parameters
----------
pool : multiprocessing.Pool
The pool for multiprocessing.
func : function
The function to process lines.
l : File object
The file object.
semaphore : multiprocessing.Semaphore, optional, default: None
The semaphore to acquire. If None (default), the object will be passed
without control.
nlines : int, optional, default: None
The number of lines to pass to the function each time. If None (default),
only one line will be passed to the function.
unordered : bool, optional, default: True
Whether the process can be unordered.
return_num : bool, optional, default: False
If True, adds a counter to an iterable.
start : int, optional, default: 0
The start number of the counter.
extra : object, optional, default: None
The extra object passed to the item.
interval : int, optional, default: None
The interval of items that will be passed to the function. For example,
if set to 10, a item will be passed once every 10 items and others will
be dropped.
bar : bool, optional, default: True
If True, show a tqdm bar for the iteration.
desc : str, optional, default: None
The description of the iteration shown in the bar.
unit : str, optional, default: it
The unit of the iteration shown in the bar.
total : int, optional, default: None
The total number of the iteration shown in the bar.
Returns
-------
object
An object that can be iterated.
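
    Examples
    --------
    A sketch that maps a function over the lines of a file with four
    processes; the file name is illustrative:

    >>> with Pool(4) as pool, open("dump.lammpstrj") as f:  # doctest: +SKIP
    ...     for line in multiopen(pool, str.strip, f, desc="Reading"):
    ...         pass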
"""
obj = l
if nlines:
obj = itertools.zip_longest(*[obj] * nlines)
if interval:
obj = itertools.islice(obj, 0, None, interval)
if return_num:
obj = enumerate(obj, start)
if semaphore:
obj = produce(semaphore, obj, extra)
if unordered:
obj = pool.imap_unordered(func, obj, 100)
else:
obj = pool.imap(func, obj, 100)
if bar:
obj = tqdm(obj, desc=desc, unit=unit, total=total, disable=None)
return obj
class SCOUROPTIONS:
"""Scour (SVG optimization) options."""
strip_xml_prolog = True
remove_titles = True
remove_descriptions = True
remove_metadata = True
remove_descriptive_elements = True
strip_comments = True
enable_viewboxing = True
strip_xml_space_attribute = True
strip_ids = True
shorten_ids = True
newlines = False
class SharedRNGData:
"""Share ReacNetGenerator data with a class of the submodule.
Parameters
----------
rng: reacnetgenerator.ReacNetGenerator
The centered ReacNetGenerator class.
usedRNGKeys: list of strs
Keys that needs to pass from ReacNetGenerator class to the submodule.
returnedRNGKeys: list of strs
Keys that needs to pass from the submodule to ReacNetGenerator class.
extraNoneKeys: list of strs, optional, default: None
Set keys to None, which will be used in the submodule.
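
    Examples
    --------
    A sketch of a hypothetical submodule; the class and key names are
    illustrative:

    >>> class _Analyzer(SharedRNGData):  # doctest: +SKIP
    ...     def __init__(self, rng):
    ...         SharedRNGData.__init__(self, rng, ["atomname"], ["mols"])
    ...     def analyze(self):
    ...         self.mols = []  # computed from self.atomname
    ...         self.returnkeys()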
"""
def __init__(
self,
rng: "reacnetgenerator.ReacNetGenerator",
usedRNGKeys: List[str],
returnedRNGKeys: List[str],
extraNoneKeys: Optional[List[str]] = None,
) -> None:
self.rng = rng
self.returnedRNGKeys = returnedRNGKeys
for key in usedRNGKeys:
setattr(self, key, getattr(self.rng, key))
for key in returnedRNGKeys:
setattr(self, key, None)
if extraNoneKeys is not None:
for key in extraNoneKeys:
setattr(self, key, None)
def returnkeys(self) -> None:
"""Return back keys to ReacNetGenerator class."""
for key in self.returnedRNGKeys:
setattr(self.rng, key, getattr(self, key))
def checksha256(filename: str, sha256_check: Union[str, List[str]]) -> Optional[bool]:
    """Check whether the SHA256 of a file is correct.

    Parameters
    ----------
    filename : str
        The filename.
    sha256_check : str or list of strs
        The expected SHA256 value(s).

    Returns
    -------
    bool or None
        Whether the SHA256 is correct, or None if the file does not exist.
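
    Examples
    --------
    An illustrative check; the filename is hypothetical and the hash shown
    is the SHA256 of an empty file:

    >>> checksha256(
    ...     "data.json",
    ...     "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
    ... )  # doctest: +SKIP
    True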
"""
    if not os.path.isfile(filename):
        return None
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
sha256 = h.hexdigest()
logger.info(f"SHA256 of {filename}: {sha256}")
if sha256 in must_be_list(sha256_check):
return True
logger.warning("SHA256 is not correct.")
logger.warning(open(filename).read())
return False
async def download_file(
    urls: Union[str, List[str]], pathfilename: str, sha256: Optional[str]
) -> str:
    """Download a file from remote URLs if it does not exist locally.

    Parameters
    ----------
    urls : str or list of strs
        The URL(s) from which the file can be downloaded.
    pathfilename : str
        The local path of the downloaded file.
    sha256 : str, optional
        The SHA256 of the file. If not None and the existing file matches,
        the download is skipped.

    Returns
    -------
    pathfilename : str
        The local path of the downloaded file.
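
    Examples
    --------
    An illustrative call; the URL and filename are hypothetical:

    >>> asyncio.run(
    ...     download_file("https://example.com/data.json", "data.json", None)
    ... )  # doctest: +SKIP
    'data.json'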
"""
s = requests.Session()
s.mount("http://", HTTPAdapter(max_retries=3))
s.mount("https://", HTTPAdapter(max_retries=3))
# download if not exists
if os.path.isfile(pathfilename) and (
sha256 is None or checksha256(pathfilename, sha256)
):
return pathfilename
# from https://stackoverflow.com/questions/16694907
    for url in must_be_list(urls):
        logger.info(f"Try to download {pathfilename} from {url}")
        # the try block wraps s.get so that connection errors are also
        # caught and the next URL is tried
        try:
            with s.get(url, stream=True) as r, open(pathfilename, "wb") as f:
                shutil.copyfileobj(r.raw, f)
            break
        except requests.exceptions.RequestException as e:
            logger.warning(f"Request {pathfilename} Error.", exc_info=e)
else:
raise RuntimeError(f"Cannot download {pathfilename}.")
return pathfilename
async def gather_download_files(urls: List[dict]) -> None:
"""Asynchronously download files from remote urls if not exists.
See download_multifiles function for details.
See Also
--------
download_multifiles
"""
await asyncio.gather(
*[
download_file(jdata["url"], jdata["fn"], jdata.get("sha256", None))
for jdata in urls
]
)
def download_multifiles(urls: List[dict]) -> None:
"""Download multiple files from dicts.
Parameters
----------
urls : list of dicts
The information of download files. Each dict should contain the following key:
- url: str or list of strs
The url(s) that is available to download.
- pathfilename: str
The downloading path of the file.
- sha256: str, optional, default: None
Sha256 of the file. If not None and match the file, the download will be skiped.
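
    Examples
    --------
    An illustrative call; the URLs and filenames are hypothetical:

    >>> download_multifiles([  # doctest: +SKIP
    ...     {"url": "https://example.com/a.json", "fn": "a.json"},
    ...     {"url": ["https://example.com/b.json",
    ...              "https://mirror.example.com/b.json"], "fn": "b.json"},
    ... ])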
"""
asyncio.run(gather_download_files(urls))
def run_mp(nproc: int, **kwargs: Any) -> Iterable[Any]:
"""Process a file with multiple processors.
Parameters
----------
nproc : int
The number of processors to be used.
**kwargs : dict, optional
Other parameters can be found in the `multiopen` method.
Yields
------
object
The yielded object from the `multiopen` method.
See Also
--------
multiopen
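
    Examples
    --------
    A sketch that processes the lines of a file with four processes; the
    file name is illustrative:

    >>> with open("dump.lammpstrj") as f:  # doctest: +SKIP
    ...     for line in run_mp(4, func=str.strip, l=f, desc="Reading"):
    ...         pass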
"""
pool = Pool(nproc, maxtasksperchild=1000)
semaphore = Semaphore(nproc * 150)
try:
results = multiopen(pool=pool, semaphore=semaphore, **kwargs)
for item in results:
yield item
semaphore.release()
    except BaseException:
logger.exception("run_mp failed")
pool.terminate()
raise
else:
pool.close()
finally:
pool.join()
def must_be_list(obj: Union[Any, List[Any]]) -> List[Any]:
"""Convert a object to a list if the object is not a list.
Parameters
----------
obj : Object
The object to convert.
Returns
-------
obj: list
If the input object is not a list, returns a list that only contains that
object. Otherwise, returns that object.
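
    Examples
    --------
    Non-list objects are wrapped; lists pass through:

    >>> must_be_list("C")
    ['C']
    >>> must_be_list(["C", "H"])
    ['C', 'H']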
"""
if isinstance(obj, list):
return obj
return [obj]