diff --git a/web3/providers/ipc.py b/web3/providers/ipc.py index 1c8ab36..dec7daa 100644 --- a/web3/providers/ipc.py +++ b/web3/providers/ipc.py @@ -25,43 +25,19 @@ from .base import JSONBaseProvider @contextlib.contextmanager def get_ipc_socket(ipc_path, timeout=0.1): - - # On Windows named pipe is used. Simulate socket with it. if sys.platform == 'win32': - import win32file - import pywintypes - - class NamedPipe(object): - def __init__(self, ipc_path): - try: - self.handle = win32file.CreateFile( - ipc_path, win32file.GENERIC_READ | win32file.GENERIC_WRITE, - 0, None, win32file.OPEN_EXISTING, 0, None) - except pywintypes.error as err: - raise IOError(err) - - def recv(self, max_length): - (err, data) = win32file.ReadFile(self.handle, max_length) - if err: - raise IOError(err) - return data - - def sendall(self, data): - return win32file.WriteFile(self.handle, data) - - def close(self): - self.handle.close() + # On Windows named pipe is used. Simulate socket with it. + from web3.utils.windows import NamedPipe pipe = NamedPipe(ipc_path) - yield pipe - pipe.close() - + with contextlib.closing(pipe): + yield pipe else: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.connect(ipc_path) sock.settimeout(timeout) - yield sock - sock.close() + with contextlib.closing(sock): + yield sock def get_default_ipc_path(testnet=False): diff --git a/web3/utils/windows.py b/web3/utils/windows.py new file mode 100644 index 0000000..0a7c861 --- /dev/null +++ b/web3/utils/windows.py @@ -0,0 +1,31 @@ +import sys + + +if sys.platform != 'win32': + raise ImportError("This module should not be imported on non `win32` platforms") + + +import win32file # noqa: E402 +import pywintypes # noqa: E402 + + +class NamedPipe(object): + def __init__(self, ipc_path): + try: + self.handle = win32file.CreateFile( + ipc_path, win32file.GENERIC_READ | win32file.GENERIC_WRITE, + 0, None, win32file.OPEN_EXISTING, 0, None) + except pywintypes.error as err: + raise IOError(err) + + def recv(self, max_length): + (err, data) = win32file.ReadFile(self.handle, max_length) + if err: + raise IOError(err) + return data + + def sendall(self, data): + return win32file.WriteFile(self.handle, data) + + def close(self): + self.handle.close()