//! The user-facing JSON web server that listens for inference requests: the
//! "front end" of the system. The inference route is registered automatically
//! and distributes inference computation across the pool of workers.
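//!
//! Routes served by this module:
//! - `POST /inference`: run an inference task on an idle worker
//! - `GET /workers`: list all workers
//! - `GET /workers/_status`: report the status of every worker
//! - `GET /workers/_info`: report per-worker statistics
//!
//! A minimal sketch of wiring these routes into an actix-web `App`; the
//! `Manager` constructor and bind address below are assumptions for
//! illustration, not part of this module:
//!
//! ```ignore
//! use actix_web::{web, App, HttpServer};
//! use std::sync::RwLock;
//!
//! // Hypothetical constructor; build the Manager however the crate actually does.
//! let manager = web::Data::new(RwLock::new(Manager::new()));
//! HttpServer::new(move || {
//!     App::new()
//!         .app_data(manager.clone())
//!         .service(inference)
//!         .service(worker_status)
//!         .service(all_workers)
//!         .service(worker_info)
//! })
//! .bind(("0.0.0.0", 8080))?
//! .run()
//! .await?;
//! ```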

use super::WebError;

use crate::manager::{self, Manager};

use crate::rpc;
use crate::rpc::worker_client::WorkerClient;
use tonic::Request;

use crate::worker::WorkerStatus;
use crate::{config, torch};

use actix_web::{get, post, web, HttpRequest, Responder};
use anyhow::anyhow;
use tracing::*;

use std::sync::RwLock;

type Result<T> = std::result::Result<T, WebError>;

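/// HTTP request to run a single inference task: pick an idle worker, forward
/// the task to it over gRPC, and return the model output together with the
/// duration reported for the computation.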
#[post("/inference")]
pub async fn inference(
    req: web::Json<torch::InferenceTask>,
    state: web::Data<RwLock<Manager>>,
) -> Result<impl Responder> {
    // Parse the input request
    let input = req.into_inner();
    info!("got inference request: {:?}", input);

    // Get a handle to an idle worker
    let worker = {
        let worker = {
            let m = state.read().unwrap();
            m.get_idle_worker()
        };

        debug!("found idle worker");

        match worker {
            Some(worker) => Ok::<manager::Handle, anyhow::Error>(worker),
            None => {
                warn!("all workers are busy");
                Err(anyhow!("all workers are busy"))
                /*
                if config::AUTO_SCALE {
                    // Start a new worker
                    info!("dynamically starting a new worker");
                    manager.start_new_workers(1).await?;
                    let worker = {
                        let worker = manager.get_idle_worker();
                        if let Some(worker) = worker {
                            manager.set_worker_status(worker.pid, WorkerStatus::Working);
                            debug!("set idle worker to busy");
                            Ok(worker)
                        } else {
                            Err(anyhow!(
                                "all workers are busy, even after starting a new worker"
                            ))
                        }
                    }?;
                    Ok(worker)
                } else {
                    Err(anyhow!("all workers are busy: retry again later"))
                }
                */
            }
        }
    }?;

    // Send the inference request to the worker via RPC
    let channel = worker.channel.clone();
    debug!("sending inference request");

    // Check whether this deployment uses fast workers, which skip the
    // busy/idle status bookkeeping around each request
    let fast_workers = {
        let s = state.read().unwrap();
        s.config.get_bool("manager.fast_workers")?
    };

    //let output = Manager::run_inference(channel, input).await?;

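    // Build a gRPC client over the worker's channel and convert the task into
    // the RPC request type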
    let mut worker_client = WorkerClient::new(channel);
    let ty = input.inference_type.clone();
    let req = Request::new(input.into());

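    // Unless fast workers are enabled, mark this worker as busy so the manager
    // does not hand it another request while this one is in flight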
    if !fast_workers {
        let mut manager = state.write().unwrap();
        manager.set_worker_status(worker.pid, WorkerStatus::Working);
        debug!("set idle worker to busy");
    }
    let rpc_result = worker_client.compute_inference(req).await;

    // Mark the worker as Idle again, whether or not the RPC succeeded
    {
        let mut manager = state.write().unwrap();
        manager.set_worker_status(worker.pid, WorkerStatus::Idle);
    }

    let rpc_output: rpc::Inference = rpc_result
        .map_err(|e| anyhow!("inference RPC failed: {}", e))?
        .into_inner();

    // Parse output
    let output = match ty {
        torch::InferenceType::ImageClassification { .. } => {
            let classes: Vec<torch::Class> = rpc_output.classification.unwrap().into();
            torch::Inference::Classification(classes)
        }
        torch::InferenceType::ImageToImage => {
            torch::Inference::B64Image(rpc_output.image.unwrap().into())
        }
        _ => unimplemented!(),
    };

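    // Pair the model output with the duration reported in the RPC response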
    let res = (
        output,
        std::time::Duration::from_secs_f32(rpc_output.duration),
    );

    debug!("received inference response");

    info!("finished serving inference request");

    Ok(web::Json(res))
}

/// HTTP request to get the status of all workers
#[get("/workers/_status")]
pub async fn worker_status(
    _req: HttpRequest,
    state: web::Data<RwLock<Manager>>,
) -> Result<impl Responder> {
    let status = {
        let manager = state.read().unwrap();
        manager.all_status().unwrap()
    };

    Ok(web::Json(status))
}

/// HTTP request to get all workers
#[get("/workers")]
pub async fn all_workers(_req: HttpRequest, state: web::Data<RwLock<Manager>>) -> impl Responder {
    let workers = {
        let manager = state.read().unwrap();
        manager.all_workers().unwrap()
    };

    web::Json(workers)
}

/// HTTP request to get statistics for all workers
#[get("/workers/_info")]
pub async fn worker_info(
    _req: HttpRequest,
    state: web::Data<RwLock<Manager>>,
) -> Result<impl Responder> {
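    // Collect statistics from every worker; the read guard is held across the
    // .await because all_stats() borrows the manager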
    let manager = state.read().unwrap();
    let stats = manager.all_stats().await?;
    let stats_list = stats.into_iter().collect::<Vec<_>>();
    Ok(web::Json(stats_list))
}