worker_offline.py 5.13 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jan  3 17:10:23 2018

@author: rbaraglia
"""
import os
import argparse
import thread
import logging
import json
import subprocess
import configparser
Rudy BARAGLIA's avatar
Rudy BARAGLIA committed
15
import re
16
import tenacity
17
from signal_trimming import *
18 19 20 21 22 23 24 25 26 27 28 29 30

from ws4py.client.threadedclient import WebSocketClient

# Load the worker configuration from 'worker.cfg' (path is relative to the
# current working directory).
worker_settings = configparser.ConfigParser()
worker_settings.read('worker.cfg')

# [server_params]: where the dispatch server is listening.
SERVER_IP, SERVER_PORT, SERVER_TARGET = (
    worker_settings.get('server_params', option)
    for option in ('server_ip', 'server_port', 'server_target')
)

# [worker_params]: decoder command line and temporary file directory.
DECODER_COMMAND, TEMP_FILE_PATH = (
    worker_settings.get('worker_params', option)
    for option in ('decoder_command', 'temp_file_location')
)

# Preprocessing is enabled only by the exact lowercase string 'true'.
PREPROCESSING = worker_settings.get('worker_params', 'preprocessing') == 'true'

31 32 33 34
class NoRouteException(Exception):
    """Raised when connecting to the server fails with "[Errno 113]"."""
class ConnexionRefusedException(Exception):
    """Raised when connecting to the server fails with "[Errno 111]"."""
35 36 37 38

class WorkerWebSocket(WebSocketClient):
    """WebSocket client that transcribes audio files sent by the server.

    Each job arrives as a JSON object carrying a 'uuid' and a base64-encoded
    'file'. The audio is written to TEMP_FILE_PATH, optionally trimmed, passed
    to the offline decoder script, and the transcription (or an error) is sent
    back over the same socket.
    """

    def __init__(self, uri):
        """Create the client for the websocket URI *uri*."""
        WebSocketClient.__init__(self, url=uri, heartbeat_freq=10)
        # Defined up front so send_result/send_error never hit a missing
        # attribute when called before the first job has been received.
        self.client_uuid = None
        self.fileName = None
        self.file = None
        self.filepath = None

    def opened(self):
        pass

    def guard_timeout(self):
        pass

    def received_message(self, m):
        """Handle one incoming message.

        Messages that are not valid JSON are only logged. JSON messages that
        are not objects, or that carry no 'uuid' key, are ignored. Anything
        else is processed as a transcription job.
        """
        try:
            json_msg = json.loads(str(m))
        except ValueError:
            # Not JSON: nothing to process, just trace it.
            logging.debug("Message received: %s" % str(m))
            return
        if not isinstance(json_msg, dict) or 'uuid' not in json_msg:
            return

        self.client_uuid = json_msg['uuid']
        self.fileName = self.client_uuid.replace('-', '')
        # The uuid comes from the network and becomes part of file names, so
        # refuse anything that is empty or contains a path component.
        if not self.fileName or os.path.basename(self.fileName) != self.fileName:
            logging.error("Rejected job with unusable uuid: %r", self.client_uuid)
            self.send_error("Invalid uuid")
            return

        try:
            if not self._store_audio(json_msg):
                return
            logging.debug("FileName received: %s" % self.fileName)
            # TODO: preprocessing ? (sox python)
            if PREPROCESSING:
                logging.debug("Trimming signal")
                # Input and output path are the same file.
                trim_silence_segments(self.filepath, self.filepath, chunk_size=100, threshold_factor=0.85, side_effect_accomodation=0)
            self._run_decoder()
            self._report_transcription()
        finally:
            # Temporary files are removed even when a step above raised.
            self._remove_temp_files()

    def _store_audio(self, json_msg):
        """Decode the base64 'file' payload and write it to the temp directory.

        Returns True on success. When the payload is missing or cannot be
        decoded, an error is sent to the server and False is returned.
        """
        import base64  # function-scope import: works on Python 2 and 3

        try:
            self.file = base64.b64decode(json_msg['file'])
        except (KeyError, TypeError, ValueError) as err:
            logging.error("Could not decode audio payload: %s", err)
            self.send_error("Could not decode audio payload")
            return False
        self.filepath = os.path.join(TEMP_FILE_PATH, self.fileName + '.wav')
        with open(self.filepath, 'wb') as f:
            f.write(self.file)
        return True

    def _run_decoder(self):
        """Run the offline decoder script on the stored audio file.

        The command is passed as an argument list (no shell), so the file name
        can never be interpreted as shell syntax.
        """
        # NOTE(review): DECODER_COMMAND from worker.cfg is not what gets
        # executed; the command below is hard-coded.
        command = ["./decode.sh", "../systems/models", self.fileName + ".wav"]
        logging.debug("Running decoder in 'scripts': %s", ' '.join(command))
        try:
            subprocess.call(command, cwd="scripts")
        except OSError as err:
            # A missing script or directory is reported to the server by the
            # result check that follows (no transcription file is produced).
            logging.error("Could not start decoder: %s", err)

    def _report_transcription(self):
        """Send the decoder output to the server, or an error if there is none."""
        # Presumably written by decode.sh; only its existence is checked here.
        result_path = 'trans/decode_' + self.fileName + '.log'
        if not os.path.isfile(result_path):
            logging.error("Worker Failed to create transcription file")
            self.send_error("File was not created by worker")
            return
        with open(result_path, 'r') as resultFile:
            result = resultFile.read().strip()
        logging.debug("Transcription is: %s" % result)
        self.send_result(result)

    def _remove_temp_files(self):
        """Delete every regular file found in TEMP_FILE_PATH."""
        try:
            entries = os.listdir(TEMP_FILE_PATH)
        except OSError as err:
            logging.warning("Could not list temporary directory %s: %s", TEMP_FILE_PATH, err)
            return
        for entry in entries:
            entry_path = os.path.join(TEMP_FILE_PATH, entry)
            if not os.path.isfile(entry_path):
                continue
            try:
                os.remove(entry_path)
            except OSError as err:
                logging.warning("Could not remove temporary file %s: %s", entry_path, err)

    def post(self, m):
        logging.debug('POST received')

    def send_result(self, result=None):
        """Send the transcription *result* for the current job, then clear its uuid."""
        # 'trust_ind' is a hard-coded value.
        # NOTE(review): looks like a placeholder - confirm.
        msg = json.dumps({u'uuid': self.client_uuid, u'transcription': result, u'trust_ind': u"0.1235"})
        self.client_uuid = None
        self.send(msg)

    def send_error(self, message):
        """Send the error text *message* for the current job."""
        msg = json.dumps({u'uuid': self.client_uuid, u'error': message})
        self.send(msg)

    def closed(self, code, reason=None):
        pass

    def finish_request(self):
        pass

    
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
@tenacity.retry(
    wait=tenacity.wait_fixed(2),
    stop=tenacity.stop_after_delay(45),
    # retry_if_exception() expects a predicate; given an exception class it
    # is always truthy and retries on every error. Match on the type instead.
    retry=tenacity.retry_if_exception_type(ConnexionRefusedException),
)
def connect_to_server(ws):
    """Connect *ws* to the server and serve until the connection ends.

    A refused connection is retried every 2 seconds for up to 45 seconds.

    Raises:
        NoRouteException: the host is unreachable (not retried).
        ConnexionRefusedException: the connection was refused; tenacity
            retries on it and raises its own error once the delay is spent.
    """
    import errno  # function-scope import keeps the file's import block untouched

    try:
        logging.info("Attempting to connect to server at %s:%s" % (SERVER_IP, SERVER_PORT))
        ws.connect()
        logging.info("Worker succefully connected to server at %s:%s" % (SERVER_IP, SERVER_PORT))
        ws.run_forever()
    except KeyboardInterrupt:
        logging.info("Worker interrupted by user")
        ws.close()
    except Exception as e:
        # Prefer the portable errno constants; the message match is kept as a
        # fallback for errors that only carry the Linux number in their text.
        code = getattr(e, 'errno', None)
        description = str(e)
        if code == errno.EHOSTUNREACH or "[Errno 113]" in description:
            logging.info("Failed to connect")
            raise NoRouteException
        if code == errno.ECONNREFUSED or "[Errno 111]" in description:
            logging.info("Failed to connect")
            raise ConnexionRefusedException
        logging.debug(e)
    logging.info("Worker stopped")

123 124 125 126 127 128 129 130
def main():
    """Parse the command line, prepare the temp directory and run the worker.

    The websocket URI defaults to the one built from the server settings in
    the configuration and can be overridden with -u/--uri.
    """
    default_uri = "ws://%s:%s%s" % (SERVER_IP, SERVER_PORT, SERVER_TARGET)
    arg_parser = argparse.ArgumentParser(description='Worker for linstt-dispatch')
    arg_parser.add_argument(
        '-u', '--uri',
        dest="uri",
        default=default_uri,
        help="Server<-->worker websocket URI",
    )
    options = arg_parser.parse_args()

    # Received audio files are written here, so the directory must exist.
    if not os.path.isdir(TEMP_FILE_PATH):
        os.mkdir(TEMP_FILE_PATH)

    print('#' * 50)
    logging.basicConfig(level=logging.DEBUG, format="%(levelname)8s %(asctime)s %(message)s ")
    logging.info('Starting up worker')

    worker_socket = WorkerWebSocket(options.uri)
    try:
        connect_to_server(worker_socket)
    except Exception:
        logging.error("Worker did not manage to connect to server at %s:%s" % (SERVER_IP, SERVER_PORT))
139 140
# Start the worker only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()