#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Offline decoding worker for linstt-dispatch (worker_offline.py).

Created on Wed Jan  3 17:10:23 2018

@author: rbaraglia
"""
import argparse
import base64
import configparser
import errno
import json
import logging
import os
import re
import subprocess
import thread

import tenacity
from ws4py.client.threadedclient import WebSocketClient

# LOADING CONFIGURATION
# Every setting is read once, at import time, from 'worker.cfg' (path relative
# to the current working directory).  ConfigParser.read() silently ignores a
# file it cannot open, so a missing file only shows up as a NoSectionError on
# the first get() below.
worker_settings = configparser.ConfigParser()
worker_settings.read('worker.cfg')

# [server_params]: pieces of the default websocket URI built in main().
SERVER_IP = worker_settings.get('server_params', 'server_ip')
SERVER_PORT = worker_settings.get('server_params', 'server_port')
SERVER_TARGET = worker_settings.get('server_params', 'server_target')

# [worker_params]: decoder command line and temporary-file directory.
DECODER_COMMAND = worker_settings.get('worker_params', 'decoder_command')
TEMP_FILE_PATH = worker_settings.get('worker_params', 'temp_file_location')
# Only the exact string 'true' enables preprocessing.
PREPROCESSING = worker_settings.get('worker_params', 'preprocessing') == 'true'


class NoRouteException(Exception):
    """Raised by connect_to_server when the failure mentions "[Errno 113]"."""
class ConnexionRefusedException(Exception):
    """Raised by connect_to_server when the failure mentions "[Errno 111]".

    This is the exception type the retry policy on connect_to_server names.
    """


class WorkerWebSocket(WebSocketClient):
    """Websocket client that turns received audio files into transcriptions.

    Every JSON message that carries a ``uuid`` key is handled as one job: the
    base64-encoded ``file`` payload is written to ``./wavs/``, decoded by
    ``scripts/decode.sh`` and the content of ``trans/decode_<name>.log`` is
    sent back to the server.  An error message is sent when no result file
    exists after decoding.
    """

    def __init__(self, uri):
        """Create the client for *uri* with ``heartbeat_freq=10``."""
        WebSocketClient.__init__(self, url=uri, heartbeat_freq=10)
        # Per-job state, (re)assigned by received_message().
        self.client_uuid = None
        self.fileName = None
        self.file = None

    def opened(self):
        """Connection-opened hook; nothing to do."""
        pass

    def guard_timeout(self):
        """No-op."""
        pass

    def received_message(self, m):
        """Process one message received from the server.

        Messages that are not JSON are only logged.  JSON messages that are
        not an object, or that have no ``uuid`` key, are ignored.

        Args:
            m: the received message; ``str(m)`` is parsed as JSON.
        """
        try:
            json_msg = json.loads(str(m))
        except ValueError:
            # Not JSON (or not decodable as text): trace it and stop.
            logging.debug("Message received: %s", m)
            return

        if not isinstance(json_msg, dict) or 'uuid' not in json_msg:
            return

        self.client_uuid = json_msg['uuid']
        self.fileName = self.client_uuid.replace('-', '')
        try:
            self.file = base64.b64decode(json_msg['file'])
        except (KeyError, TypeError, ValueError):
            # Missing 'file' key or payload that is not valid base64.
            logging.error("Missing or invalid 'file' payload for %s", self.client_uuid)
            self.send_error("Missing or invalid file payload")
            return

        try:
            self._run_decoder()
            self._report_result()
        finally:
            # Clean up even when decoding or sending failed.
            self._delete_temp_files()

    def _run_decoder(self):
        """Write the received audio under ./wavs/ and run scripts/decode.sh on it."""
        wav_name = self.fileName + '.wav'
        try:
            with open('./wavs/' + wav_name, 'wb') as wav_file:
                wav_file.write(self.file)
            logging.debug("FileName received: %s" % self.fileName)

            # TODO: preprocessing ? (sox python)
            if PREPROCESSING:
                pass

            # Offline decoder call
            # NOTE(review): the command logged here is built from the
            # configuration, but the command actually executed below is
            # hard-coded and does not use DECODER_COMMAND.
            logging.debug(DECODER_COMMAND + ' ' + TEMP_FILE_PATH + wav_name)
            # Argument list + cwd instead of a shell string: the file name is
            # derived from network input and must never be parsed by a shell.
            subprocess.call(['./decode.sh', '../systems/models', wav_name], cwd='scripts')
        except (IOError, OSError) as err:
            # No result file will exist; _report_result() notifies the server.
            logging.error("Decoding of %s could not be run: %s", wav_name, err)

    def _report_result(self):
        """Send the transcription if the decoder produced one, else an error."""
        result_path = 'trans/decode_' + self.fileName + '.log'
        if os.path.isfile(result_path):
            with open(result_path, 'r') as result_file:
                result = result_file.read().strip()
            logging.debug("Transcription is: %s" % result)
            self.send_result(result)
        else:
            logging.error("Worker Failed to create transcription file")
            self.send_error("File was not created by worker")

    def _delete_temp_files(self):
        """Remove every regular file found directly in TEMP_FILE_PATH."""
        try:
            for entry in os.listdir(TEMP_FILE_PATH):
                # os.path.join works whether or not the setting ends with '/'.
                entry_path = os.path.join(TEMP_FILE_PATH, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
        except OSError as err:
            # Cleanup is best effort; it must not hide the job outcome.
            logging.warning("Could not clean %s: %s", TEMP_FILE_PATH, err)

    def post(self, m):
        """Log that a POST was received; *m* is not used."""
        logging.debug('POST received')

    def send_result(self, result=None):
        """Send *result* as the transcription of the current job.

        The ``trust_ind`` field is a constant; it is not computed from the
        decoder output.  The current job uuid is cleared before sending.
        """
        msg = json.dumps({
            u'uuid': self.client_uuid,
            u'transcription': result,
            u'trust_ind': u"0.1235",
        })
        self.client_uuid = None
        self.send(msg)

    def send_error(self, message):
        """Send *message* as the error of the current job and clear its uuid."""
        msg = json.dumps({u'uuid': self.client_uuid, u'error': message})
        # Same bookkeeping as send_result(): the job is over once reported.
        self.client_uuid = None
        self.send(msg)

    def closed(self, code, reason=None):
        """Connection-closed hook; nothing to do."""
        pass

    def finish_request(self):
        """No-op."""
        pass


@tenacity.retry(
    wait=tenacity.wait_fixed(2),
    stop=tenacity.stop_after_delay(45),
    # retry_if_exception() expects a predicate; given an exception class it
    # always evaluates truthy and retries on *any* error.  Match on the type.
    retry=tenacity.retry_if_exception_type(ConnexionRefusedException),
)
def connect_to_server(ws):
    """Connect *ws* to the server and serve until the connection ends.

    While the connection is refused the attempt is repeated every 2 seconds,
    for at most 45 seconds.

    Args:
        ws: websocket client exposing ``connect()``, ``run_forever()`` and
            ``close()``.

    Raises:
        NoRouteException: the error reports "no route to host" (errno 113 on
            Linux); not retried.
        ConnexionRefusedException: the error reports "connection refused"
            (errno 111 on Linux); consumed by the retry policy until it gives
            up.
    """
    try:
        logging.info("Attempting to connect to server at %s:%s" % (SERVER_IP, SERVER_PORT))
        ws.connect()
        logging.info("Worker succefully connected to server at %s:%s" % (SERVER_IP, SERVER_PORT))
        ws.run_forever()
    except KeyboardInterrupt:
        logging.info("Worker interrupted by user")
        ws.close()
    except Exception as e:
        # Prefer the portable errno value; keep the original text match as a
        # fallback for exceptions that only carry the message.
        err_no = getattr(e, 'errno', None)
        if err_no == errno.EHOSTUNREACH or "[Errno 113]" in str(e):
            logging.info("Failed to connect")
            raise NoRouteException(str(e))
        if err_no == errno.ECONNREFUSED or "[Errno 111]" in str(e):
            logging.info("Failed to connect")
            raise ConnexionRefusedException(str(e))
        # Any other error is only traced; the worker then stops normally.
        logging.debug(e)
    logging.info("Worker stopped")


def main():
    """Parse the command line, prepare the temp directory and run the worker.

    The websocket URI defaults to the one built from the server settings and
    can be overridden with ``-u/--uri``.
    """
    parser = argparse.ArgumentParser(description='Worker for linstt-dispatch')
    parser.add_argument(
        '-u', '--uri',
        default="ws://"+SERVER_IP+":"+SERVER_PORT+SERVER_TARGET,
        dest="uri",
        help="Server<-->worker websocket URI",
    )
    args = parser.parse_args()

    print('#'*50)
    # Configure logging first so every later step can report problems.
    logging.basicConfig(level=logging.DEBUG, format="%(levelname)8s %(asctime)s %(message)s ")
    logging.info('Starting up worker')

    # makedirs (not mkdir) so a nested temp_file_location also works.
    if not os.path.isdir(TEMP_FILE_PATH):
        os.makedirs(TEMP_FILE_PATH)

    ws = WorkerWebSocket(args.uri)
    try:
        connect_to_server(ws)
    except Exception:
        # Top-level boundary: report the failure with its cause and exit.
        logging.exception("Worker did not manage to connect to server at %s:%s" % (SERVER_IP, SERVER_PORT))


# Start the worker only when this file is run as a script.
if __name__ == "__main__":
    main()