# Source code for snimpy.snmp

#
# snimpy -- Interactive SNMP tool
#
# Copyright (C) Vincent Bernat <bernat@luffy.cx>
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#

"""
This module is a low-level interface to build SNMP requests, send
them and receive answers. It is built on top of pysnmp_ but the
exposed interface is far simpler. It is also far less complete and
there is an important dependency on the :mod:`basictypes` module for
type coercion.

.. _pysnmp: http://pysnmp.sourceforge.net/
"""

import re
import socket
import inspect
import threading
import asyncio
import ipaddress
import pysnmp.hlapi.v3arch.asyncio as v3
from pyasn1.type import univ
from pysnmp.proto import rfc1902, rfc1905
from pysnmp.smi import error
from pysnmp.smi.rfc1902 import ObjectType, ObjectIdentity


class SNMPException(Exception):
    """Root of every SNMP error raised by this module.

    Catch this class to handle any SNMP failure. Each subclass is named
    ``SNMP`` followed by the capitalized name of the SNMP error it stands
    for (for instance ``tooBig`` becomes :class:`SNMPTooBig`).
    """


class SNMPTooBig(SNMPException):
    """Raised when the agent reports the ``tooBig`` error status."""


class SNMPNoSuchName(SNMPException):
    """Raised when the agent reports the ``noSuchName`` error status."""


class SNMPBadValue(SNMPException):
    """Raised when the agent reports the ``badValue`` error status."""


class SNMPReadOnly(SNMPException):
    """Raised when the agent reports the ``readOnly`` error status."""
# Dynamically build remaining (v2) exceptions for name, obj in inspect.getmembers(error): if name.endswith("Error") and \ inspect.isclass(obj) and \ issubclass(obj, error.MibOperationError) and \ obj != error.MibOperationError: name = str("SNMP{}".format(name[:-5])) globals()[name] = type(name, (SNMPException,), {}) del name del obj class _LoopGuard: """Cancel pending tasks and close the event loop on thread exit. pysnmp's AsyncioDispatcher creates a background handle_timeout() task that runs forever. When a thread exits, this task must be properly cancelled and awaited — otherwise asyncio emits "Task was destroyed but it is pending!" warnings. This guard is stored in thread-local data alongside the engine and event loop. Because nothing else references it, it is the first object to reach refcount 0 when the thread-local dict is cleared, so its __del__ runs before the engine's dispatcher tries to clean up the same tasks.""" def __init__(self, loop): self._loop = loop def __del__(self): loop = self._loop if loop.is_closed(): return pending = asyncio.all_tasks(loop) for task in pending: task.cancel() if pending: try: loop.run_until_complete( asyncio.gather(*pending, return_exceptions=True)) except RuntimeError: pass loop.close() class _SnimpyEngine: """Manage per-thread SNMP backends and event loop.""" _tls = threading.local() @classmethod def loop(cls): """Return the per-thread event loop, creating one if needed.""" if not hasattr(cls._tls, "loop"): cls._tls.loop = asyncio.new_event_loop() asyncio.set_event_loop(cls._tls.loop) cls._tls.guard = _LoopGuard(cls._tls.loop) return cls._tls.loop @classmethod def engine(cls): """Return the per-thread v3arch SnmpEngine.""" cls.loop() if not hasattr(cls._tls, "engine"): cls._tls.engine = v3.SnmpEngine() return cls._tls.engine
class Session:
    """SNMP session.

    An instance of this object represents an SNMP session. From such an
    instance, one can get information from the associated agent.
    """

    def _run(self, coro):
        """Run an async coroutine synchronously.

        The coroutine is driven to completion on the calling thread's
        event loop, as returned by ``_SnimpyEngine.loop()``.
        """
        return _SnimpyEngine.loop().run_until_complete(coro)

    def __init__(self, host, community="public", version=2,
                 secname=None,
                 authprotocol=None, authpassword=None,
                 privprotocol=None, privpassword=None,
                 contextname=None,
                 bulk=40, none=False):
        """Create a new SNMP session.

        :param host: The hostname or IP address of the agent to
            connect to. Optionally, the port can be appended, separated
            by a colon (``host:1161``). An IPv6 address must be enclosed
            in brackets (``[2001:db8::1]:161``). The default port is 161.
        :type host: str
        :param community: The community to transmit to the agent for
            authorization purpose. This parameter is ignored if the
            specified version is 3 (except as the default `secname`).
        :type community: str
        :param version: The SNMP version to use to talk with the agent.
            Possible values are `1`, `2` (community-based) or `3`.
        :type version: int
        :param secname: Security name to use for SNMPv3 only. Defaults
            to `community`.
        :type secname: str
        :param authprotocol: Authentication protocol to use for SNMPv3.
            This can be `None` or one of the strings `MD5`, `SHA`,
            `SHA1`, `SHA224`, `SHA256`, `SHA384` or `SHA512`.
        :type authprotocol: None or str
        :param authpassword: Authentication password if the
            authentication protocol is not `None`.
        :type authpassword: str
        :param privprotocol: Privacy protocol to use for SNMPv3. This
            can be `None` or one of the strings `DES`, `3DES`, `AES`,
            `AES128`, `AES192` or `AES256`.
        :type privprotocol: None or str
        :param privpassword: Privacy password if the privacy protocol
            is not `None`.
        :type privpassword: str
        :param contextname: Context name for SNMPv3 messages.
        :type contextname: str
        :param bulk: Max repetition value for `GETBULK` requests. Set
            to `0` or `False` to disable.
        :type bulk: int
        :param none: When enabled, will return None for not found
            values (instead of raising an exception).
        :type none: bool
        :raises ValueError: for an unsupported version, an unknown
            authentication/privacy protocol, a negative `bulk`, or when
            `none` is requested with SNMPv1.
        """
        self._host = host
        self._version = version
        self._none = none
        if version == 1 and none:
            raise ValueError("None-GET requests not compatible with SNMPv1")

        # Put authentication stuff in self._auth and self._contextdata.
        if version in [1, 2]:
            # mpModel 0 is SNMPv1, 1 is SNMPv2c.
            self._auth = v3.CommunityData(community, mpModel=version - 1)
            self._contextdata = v3.ContextData()
        elif version == 3:
            if secname is None:
                secname = community
            try:
                authprotocol = {
                    None: v3.usmNoAuthProtocol,
                    "MD5": v3.usmHMACMD5AuthProtocol,
                    "SHA": v3.usmHMACSHAAuthProtocol,
                    "SHA1": v3.usmHMACSHAAuthProtocol,
                    "SHA224": v3.usmHMAC128SHA224AuthProtocol,
                    "SHA256": v3.usmHMAC192SHA256AuthProtocol,
                    "SHA384": v3.usmHMAC256SHA384AuthProtocol,
                    "SHA512": v3.usmHMAC384SHA512AuthProtocol,
                }[authprotocol]
            except KeyError:
                raise ValueError("{} is not an acceptable authentication "
                                 "protocol".format(authprotocol))
            try:
                privprotocol = {
                    None: v3.usmNoPrivProtocol,
                    "DES": v3.usmDESPrivProtocol,
                    "3DES": v3.usm3DESEDEPrivProtocol,
                    "AES": v3.usmAesCfb128Protocol,
                    "AES128": v3.usmAesCfb128Protocol,
                    "AES192": v3.usmAesCfb192Protocol,
                    "AES256": v3.usmAesCfb256Protocol,
                }[privprotocol]
            except KeyError:
                raise ValueError("{} is not an acceptable privacy "
                                 "protocol".format(privprotocol))
            self._auth = v3.UsmUserData(secname, authpassword, privpassword,
                                        authprotocol, privprotocol)
            if contextname:
                self._contextdata = v3.ContextData(
                    contextName=rfc1902.OctetString(contextname))
            else:
                self._contextdata = v3.ContextData()
        else:
            raise ValueError("unsupported SNMP version {}".format(version))

        # Every supported version goes through the same v3arch backend.
        self._cmd_args = (_SnimpyEngine.engine(), self._auth)
        self._get_cmd = v3.get_cmd
        self._set_cmd = v3.set_cmd
        self._walk_cmd = v3.walk_cmd
        self._bulk_walk_cmd = v3.bulk_walk_cmd

        # Put transport stuff into self._transport
        self._transport = self._make_transport(host)
        self._cmd_args += (self._transport,)

        # Bulk stuff
        self.bulk = bulk

    def _make_transport(self, host):
        """Build the UDP transport target for *host*.

        *host* is a bracketed IPv6 address, a dotted IPv4 address or a
        hostname, optionally followed by ``:port`` (default 161). A
        hostname uses IPv4 when it resolves to at least one IPv4
        address, IPv6 otherwise.
        """
        mo = re.match(r'^(?:'
                      r'\[(?P<ipv6>[\d:A-Fa-f]+)\]|'
                      r'(?P<ipv4>[\d\.]+)|'
                      r'(?P<any>.*?))'
                      r'(?::(?P<port>\d+))?$',
                      host)
        port = int(mo.group("port")) if mo.group("port") else 161
        if mo.group("ipv6"):
            target, address = v3.Udp6TransportTarget, mo.group("ipv6")
        elif mo.group("ipv4"):
            target, address = v3.UdpTransportTarget, mo.group("ipv4")
        else:
            address = mo.group("any")
            results = socket.getaddrinfo(address, port,
                                         0,
                                         socket.SOCK_DGRAM,
                                         socket.IPPROTO_UDP)
            if any(x[0] == socket.AF_INET for x in results):
                target = v3.UdpTransportTarget
            else:
                target = v3.Udp6TransportTarget
        return self._run(target.create((address, port)))

    def _check_exception(self, value):
        """Check if the given ASN1 value is an exception"""
        if isinstance(value, rfc1905.NoSuchObject):
            raise SNMPNoSuchObject("No such object was found")  # noqa: F821
        if isinstance(value, rfc1905.NoSuchInstance):
            raise SNMPNoSuchInstance("No such instance exists")  # noqa: F821
        if isinstance(value, rfc1905.EndOfMibView):
            raise SNMPEndOfMibView("End of MIB was reached")  # noqa: F821

    def _check_error(self, errorIndication, errorStatus):
        """Check for SNMP protocol errors in response.

        :raises SNMPException: (or the subclass named after the SNMP
            error status, e.g. ``tooBig`` -> :class:`SNMPTooBig`).
        """
        if errorIndication:
            self._check_exception(errorIndication)
            raise SNMPException(str(errorIndication))
        if errorStatus:
            exc = re.sub(r'\W+', '', str(errorStatus.prettyPrint()))
            # exc[:1] rather than exc[0] so an empty status string does
            # not raise IndexError.
            exc = "SNMP{}".format(exc[:1].upper() + exc[1:])
            if exc in globals():
                raise globals()[exc]
            raise SNMPException(errorStatus.prettyPrint())

    def _convert(self, value):
        """Convert a PySNMP value to some native Python type"""
        try:
            # With PySNMP 4.3+, an OID is a ObjectIdentity. We try to
            # extract it while being compatible with earlier releases.
            value = value.getOid()
        except AttributeError:
            pass
        if self._none:
            if isinstance(value, rfc1905.NoSuchObject):
                return None
            if isinstance(value, rfc1905.NoSuchInstance):
                return None
        self._check_exception(value)
        # Order matters: the first matching class wins.
        convertors = {rfc1902.Integer: int,
                      rfc1902.Integer32: int,
                      rfc1902.OctetString: bytes,
                      rfc1902.IpAddress: ipaddress.IPv4Address,
                      rfc1902.Counter32: int,
                      rfc1902.Counter64: int,
                      rfc1902.Gauge32: int,
                      rfc1902.Unsigned32: int,
                      rfc1902.TimeTicks: int,
                      rfc1902.Bits: str,
                      rfc1902.Opaque: str,
                      rfc1902.univ.ObjectIdentifier: tuple,
                      # v1arch returns raw pyasn1 types
                      univ.Integer: int,
                      univ.OctetString: bytes,
                      univ.ObjectIdentifier: tuple}
        for cl, fn in convertors.items():
            if isinstance(value, cl):
                return fn(value)
        raise NotImplementedError("unable to convert {}".format(repr(value)))

    def _convert_response(self, response):
        """Check a GET/SET response and convert its variable bindings.

        :param response: the ``(errorIndication, errorStatus,
            errorIndex, varBinds)`` tuple returned by a pysnmp command.
        :return: a tuple of ``(oid, value)`` pairs.
        :raises SNMPException: on protocol errors or an empty answer.
        """
        errorIndication, errorStatus, _, varBinds = response
        self._check_error(errorIndication, errorStatus)
        results = tuple((tuple(name), self._convert(val))
                        for name, val in varBinds)
        if not results:
            raise SNMPException("empty answer")
        return results

    def get(self, *oids):
        """Retrieve an OID value using GET.

        :param oids: a list of OID to retrieve. An OID is a tuple.
        :return: a tuple of ``(oid, value)`` pairs, each value converted
            to a native Python type.
        """
        objecttypes = [ObjectType(ObjectIdentity(oid)) for oid in oids]
        return self._convert_response(self._run(
            self._get_cmd(*self._cmd_args, self._contextdata,
                          *objecttypes, lookupMib=False)))

    async def _collect(self, walker):
        """Drain an async pysnmp walker into a list of (oid, raw value).

        The matching SNMP exception is raised as soon as a response
        carries an error.
        """
        results = []
        async for errorIndication, errorStatus, _, varBinds in walker:
            self._check_error(errorIndication, errorStatus)
            results.extend((tuple(name), val) for name, val in varBinds)
        return results

    async def _walk_async(self, *oids):
        """Collect results from GETNEXT-based walk."""
        results = []
        for oid in oids:
            results.extend(await self._collect(self._walk_cmd(
                *self._cmd_args, self._contextdata,
                ObjectType(ObjectIdentity(oid)),
                lookupMib=False, lexicographicMode=False)))
        return results

    async def _bulkwalk_async(self, bulk, *oids):
        """Collect results from GETBULK-based walk."""
        results = []
        for oid in oids:
            # 0 non-repeaters, `bulk` max-repetitions.
            results.extend(await self._collect(self._bulk_walk_cmd(
                *self._cmd_args, self._contextdata,
                0, bulk,
                ObjectType(ObjectIdentity(oid)),
                lookupMib=False, lexicographicMode=False)))
        return results

    def walkmore(self, *oids):
        """Retrieve OIDs values using GETBULK or GETNEXT.

        The method is called "walk" but this is either a GETBULK or a
        GETNEXT. The latter is only used for SNMPv1 or if bulk has been
        disabled using the :meth:`bulk` property. When the agent answers
        ``tooBig``, the bulk value is halved (down to disabling GETBULK)
        and the request is retried.

        :param oids: a list of OID to retrieve. An OID is a tuple.
        :return: a tuple of ``(oid, value)`` pairs, each value converted
            to a native Python type.
        """
        if self._version == 1 or not self.bulk:
            results = self._run(self._walk_async(*oids))
        else:
            try:
                results = self._run(self._bulkwalk_async(self.bulk, *oids))
            except SNMPTooBig:
                # Let's try to ask for less values. We will never increase
                # bulk again. We cannot increase it just after the walk
                # because we may end up requesting everything twice (or
                # more). Integer division keeps bulk an int and reaches
                # False (plain GETNEXT) instead of an invalid 0.
                nbulk = self.bulk // 2 or False
                if nbulk != self.bulk:
                    self.bulk = nbulk
                    # Retry through walkmore() so the caller still gets
                    # the same unfiltered tuple this method promises.
                    return self.walkmore(*oids)
                raise
        return tuple((oid, self._convert(val)) for oid, val in results)

    def walk(self, *oids):
        """Walk from given OIDs but don't return any "extra" results.

        Only results in the subtree will be returned. The result is a
        lazy generator: :meth:`walkmore` is called for each OID as the
        generator is consumed.

        :param oids: OIDs used as a start point (tuples).
        :return: a generator of ``(oid, value)`` pairs.
        """
        return ((noid, result)
                for oid in oids
                for noid, result in self.walkmore(oid)
                if (len(noid) >= len(oid) and
                    noid[:len(oid)] == oid[:len(oid)]))

    def set(self, *args):
        """Set an OID value using SET.

        This function takes an even number of arguments, grouped in
        pairs. The first member of each pair is an OID and the second
        one is a :class:`basictypes.Type` instance whose `pack()` method
        will be used to transform it into the appropriate form.

        :return: a tuple of ``(oid, value)`` pairs as echoed back by the
            agent, each value converted to a native Python type.
        :raises ValueError: if an odd number of arguments is given.
        """
        if len(args) % 2 != 0:
            raise ValueError("expect an even number of arguments for SET")
        objecttypes = [ObjectType(ObjectIdentity(oid), val.pack())
                       for oid, val in zip(args[0::2], args[1::2])]
        return self._convert_response(self._run(
            self._set_cmd(*self._cmd_args, self._contextdata,
                          *objecttypes, lookupMib=False)))

    def __repr__(self):
        return "{}(host={},version={})".format(
            self.__class__.__name__,
            self._host,
            self._version)

    @property
    def timeout(self):
        """Get timeout value for the current session.

        :return: Timeout value in microseconds.
        """
        return self._transport.timeout * 1000000

    @timeout.setter
    def timeout(self, value):
        """Set timeout value for the current session.

        :param value: Timeout value in microseconds (stored on the
            transport in seconds).
        """
        value = int(value)
        if value <= 0:
            raise ValueError("timeout is a positive integer")
        self._transport.timeout = value / 1000000.

    @property
    def retries(self):
        """Get number of times a request is retried.

        :return: Number of retries for each request.
        """
        return self._transport.retries

    @retries.setter
    def retries(self, value):
        """Set number of times a request is retried.

        :param value: Number of retries for each request.
        """
        value = int(value)
        if value < 0:
            raise ValueError("retries is a non-negative integer")
        self._transport.retries = value

    @property
    def bulk(self):
        """Get bulk settings.

        :return: `False` if bulk is disabled or a positive integer for
            the number of repetitions.
        """
        return self._bulk

    @bulk.setter
    def bulk(self, value):
        """Set bulk settings.

        :param value: `False` or `0` to disable bulk, or a positive
            integer for the number of allowed repetitions.
        :raises ValueError: for a negative value.
        """
        if value is False:
            self._bulk = False
            return
        value = int(value)
        if value == 0:
            # Documented way to disable GETBULK (see __init__).
            self._bulk = False
            return
        if value < 0:
            raise ValueError("{} is not an appropriate value "
                             "for max repeater parameter".format(
                                 value))
        self._bulk = value