dtls.pyx 9.30 KiB
# cython: language_level=2
cimport tdtls
from tdtls cimport dtls_context_t, dtls_handler_t, session_t, dtls_alert_level_t, dtls_credentials_type_t
from libc.stdint cimport uint8_t
from libc.stddef cimport size_t
ctypedef uint8_t uint8
import socket
from libc cimport string
DTLS_CLIENT = tdtls.DTLS_CLIENT
DTLS_SERVER = tdtls.DTLS_SERVER
DTLS_LOG_EMERG = tdtls.DTLS_LOG_EMERG
DTLS_LOG_ALERT = tdtls.DTLS_LOG_ALERT
DTLS_LOG_CRIT = tdtls.DTLS_LOG_CRIT
DTLS_LOG_WARN = tdtls.DTLS_LOG_WARN
DTLS_LOG_NOTICE = tdtls.DTLS_LOG_NOTICE
DTLS_LOG_INFO = tdtls.DTLS_LOG_INFO
DTLS_LOG_DEBUG = tdtls.DTLS_LOG_DEBUG
cdef int _write(dtls_context_t *ctx, session_t *session, uint8 *buf, size_t len) except -1:
"""Send data to socket"""
self = <object>(ctx.app)
data = buf[:len]
assert session.addr.sin6.sin6_family == socket.AF_INET6
ip = socket.inet_ntop(socket.AF_INET6, session.addr.sin6.sin6_addr.s6_addr[:16])
port = socket.ntohs(session.addr.sin6.sin6_port)
cdef int ret = self.pycb['write']((ip, port), data)
return ret
cdef int _read(dtls_context_t *ctx, session_t *session, uint8 *buf, size_t len) except -1:
"""Send data to application"""
self = <object>(ctx.app)
data = buf[:len]
assert session.addr.sin6.sin6_family == socket.AF_INET6
ip = socket.inet_ntop(socket.AF_INET6, session.addr.sin6.sin6_addr.s6_addr[:16])
port = socket.ntohs(session.addr.sin6.sin6_port)
cdef int ret = self.pycb['read']((ip, port), data)
return ret
cdef int _event(dtls_context_t *ctx, session_t *session, dtls_alert_level_t level, unsigned short code) except -1:
"""The event handler is called when a message from the alert protocol is received or the state of the DTLS session changes."""
self = <object>(ctx.app)
if self.pycb['event'] != None:
self.pycb['event'](level, code)
else:
print "event:", hex(level), hex(code)
return 0;
cdef int _get_psk_info(dtls_context_t *ctx,
const session_t *session,
dtls_credentials_type_t req_type,
const unsigned char *desc_data,
size_t desc_len,
unsigned char *result_data,
size_t result_length) except -1:
"""Called during handshake to get information related to the psk key exchange.
The type of information requested is indicated by @p type
which will be one of DTLS_PSK_HINT, DTLS_PSK_IDENTITY, or DTLS_PSK_KEY.
The called function must store the requested item in the buffer @p result
of size @p result_length.
On success, the function must return
the actual number of bytes written to @p result, or a
value less than zero on error. The parameter @p desc may
contain additional request information (e.g. the psk_identity
for which a key is requested when @p type == @c DTLS_PSK_KEY.
@param ctx The current dtls context.
@param session The session where the key will be used.
@param type The type of the requested information.
@param desc Additional request information
@param desc_len The actual length of desc.
@param result Must be filled with the requested information.
@param result_length Maximum size of @p result.
@return The number of bytes written to @p result or a value
less than zero on error. """
self = <DTLS>(ctx.app)
assert session.addr.sin6.sin6_family == socket.AF_INET6
ip = socket.inet_ntop(socket.AF_INET6, session.addr.sin6.sin6_addr.s6_addr[:16])
port = socket.ntohs(session.addr.sin6.sin6_port)
desc = desc_data[:desc_len]
#result = result_data[:result_length]
cdef char *tmp
if req_type == tdtls.DTLS_PSK_HINT: # ??? TODO
#print "PSK HINT", ip, port, desc
pass
elif req_type == tdtls.DTLS_PSK_IDENTITY:
#print "PSK ID", ip, port, desc.hex()
l = len(self.pskId)
if result_length >= l:
#result = self.pskId
string.memcpy(result_data, <char*>(self.pskId), l)
#print result_data[:l], result_data[:l].hex(), l
return l
else:
return -1
elif req_type == tdtls.DTLS_PSK_KEY:
#print "PSK KEY", ip, port, desc, desc.hex()
if desc in self.pskStore.keys():
l = len(self.pskStore[desc])
#result = self.pskStore[desc]
tmp = self.pskStore[desc]
string.memcpy(result_data, tmp, l)
#print result_data[:l], result_data[:l].hex(), l
return l
else:
return -1
else:
return -1
return 0
cdef class Session:
cdef session_t session
def __init__(self, addr, int port=0, int flowinfo=0, int scope_id=0):
assert sizeof(self.session.addr.sin6) == 28
self.session.size = sizeof(self.session.addr.sin6)
self.session.addr.sin6.sin6_family = socket.AF_INET6
tmpaddr = socket.inet_pton(self.session.addr.sin6.sin6_family, addr)
string.memcpy(self.session.addr.sin6.sin6_addr.s6_addr, <char*>tmpaddr, 16)
self.session.addr.sin6.sin6_port = socket.htons(port)
self.session.addr.sin6.sin6_flowinfo = flowinfo
self.session.addr.sin6.sin6_scope_id = scope_id
self.session.ifindex = 0
@property
def family(self):
return self.session.addr.sin6.sin6_family
@property
def addr(self):
return socket.inet_ntop(self.session.addr.sin6.sin6_family, self.session.addr.sin6.sin6_addr.s6_addr[:16])
@property
def port(self):
return socket.ntohs(self.session.addr.sin6.sin6_port)
@property
def flowinfo(self):
return self.session.addr.sin6.sin6_flowinfo
@property
def scope_id(self):
return self.session.addr.sin6.sin6_scope_id
@property
def ifindex(self):
return self.session.ifindex
cdef session_t* getSession(self):
return &self.session
cdef p(self):
print "Sesion dump:", self.session.size, self.family, self.addr, self.port, self.flowinfo, self.scope_id, self.ifindex
cdef class Connection(Session):
cdef DTLS d
def __init__(self, DTLS dtls, Session s):
super().__init__(addr = s.addr, port = s.port, flowinfo=s.flowinfo, scope_id=s.scope_id)
self.d = dtls
def __del__(self):
self.d.close(self)
self.d.resetPeer(self)
cdef class MCConnection(Session):
cdef DTLS d
def __init__(self, DTLS dtls, Session s):
super().__init__(addr = s.addr, port = s.port, flowinfo=s.flowinfo, scope_id=s.scope_id)
self.d = dtls
def __del__(self):
self.d.joinLeaveGroupe(self.addr, self.d._sock, join=False)
self.d.resetPeer(self)
cdef class DTLS:
cdef dtls_context_t *ctx
cdef dtls_handler_t cb
cdef public object pycb
cdef public char* pskId
cdef public object pskStore
#@property
#def pycb(self):
#return self.pycb
#@property
#def pskId(self):
#return self.pskId
#@property
#def pskStore(self):
#return self.pskStore
def __cinit__(self):
tdtls.dtls_init()
self.ctx = tdtls.dtls_new_context(<void*>self)
self.cb.write = _write
self.cb.read = _read
self.cb.event = _event
self.cb.get_psk_info = _get_psk_info
tdtls.dtls_set_handler(self.ctx, &self.cb)
def __dealloc__(self):
tdtls.dtls_free_context(self.ctx)
def __init__(self, read=None, write=None, event=None, pskId=b"Id", pskStore={b"Id": b"secret"}):
self.pycb = dict()
if read == None:
read = self.p
self.pycb['read'] = read
if write == None:
write = self.p
self.pycb['write'] = write
self.pycb['event'] = event
self.pskId = pskId
self.pskStore = pskStore
def p(self, x, y):
print "default cb, addr:", x,"data:", y
return len(y)
#int dtls_connect(dtls_context_t *ctx, const session_t *dst)
def connect(self, addr, port=0, flowinfo=0, scope_id=0):
session = Session(addr=addr, port=port, flowinfo=flowinfo, scope_id=scope_id)
#session.p()
ret = tdtls.dtls_connect(self.ctx, session.getSession());
if(ret == 0):
print "already connected to", addr
return Connection(self, session)
elif ret > 0:
return Connection(self, session)
else:
print "error", ret
return None
#int dtls_close(dtls_context_t *ctx, const session_t *remote)
def close(self, Session session: Session):
ret = tdtls.dtls_close(self.ctx, session.getSession())
if ret != 0:
print "Error in close:", ret
raise Exception()
#dtls_peer_t *dtls_get_peer(const dtls_context_t *context, const session_t *session);
#void dtls_reset_peer(dtls_context_t *ctx, dtls_peer_t *peer)
def resetPeer(self, Session session: Session):
tdtls.dtls_reset_peer(self.ctx, tdtls.dtls_get_peer(self.ctx, session.getSession()))
#int dtls_write(dtls_context_t *ctx, session_t *session, uint8 *buf, size_t len)
def write(self, Session remote: Session, data: bytes):
"""send data to remote"""
return tdtls.dtls_write(self.ctx, remote.getSession(), data, len(data))
#void dtls_check_retransmit(dtls_context_t *context, clock_time_t *next)
def checkRetransmit(self):
cdef tdtls.clock_time_t t = 0;
tdtls.dtls_check_retransmit(self.ctx, &t)
return t
#int dtls_handle_message(dtls_context_t *ctx, session_t *session, uint8 *msg, int msglen)
def handleMessage(self, session, msg):
return tdtls.dtls_handle_message(self.ctx, (<Session?>session).getSession(), msg, len(msg))
def handleMessageAddr(self, addr, port, msg):
session = Session(addr, port)
return tdtls.dtls_handle_message(self.ctx, (<Session?>session).getSession(), msg, len(msg))
def setLogLevel(level):
tdtls.dtls_set_log_level(level)
def dtlsGetLogLevel():
return tdtls.dtls_get_log_level()