Project
Loading...
Searching...
No Matches
PIDBase.cxx
Go to the documentation of this file.
1// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
14
15#include "TRDPID/PIDBase.h"
16#include "DataFormatsTRD/PID.h"
17#include "Framework/Logger.h"
18
19#ifdef TRDPID_WITH_ONNX
20#include "TRDPID/ML.h"
21#endif
22#include "TRDPID/LQND.h"
23#include "TRDPID/Dummy.h"
24
25namespace o2
26{
27namespace trd
28{
29
30std::array<float, constants::NCHARGES> PIDBase::getCharges(const Tracklet64& tracklet, const int layer, const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, float snp, float tgl) const noexcept
31{
32 // Check z-row merging needs to be performed to recover full charge information
33 if (trk.getIsCrossingNeighbor(layer) && trk.getHasNeighbor()) { // tracklet needs correction
34 for (const auto& trklt : input.getTRDTracklets()) { // search for nearby tracklet
35 if (std::abs(tracklet.getPadCol() - trklt.getPadCol()) <= 1 && std::abs(tracklet.getPadRow() - trklt.getPadRow()) == 1) {
36 if (tracklet.getTrackletWord() == trklt.getTrackletWord()) { // skip original tracklet
37 continue;
38 }
39
40 // Add charge information
41 const auto [aQ0, aQ1, aQ2] = correctCharges(tracklet, snp, tgl);
42 const auto [bQ0, bQ1, bQ2] = correctCharges(tracklet, snp, tgl);
43 return {aQ0 + bQ0, aQ1 + bQ1, aQ2 + bQ2};
44 }
45 }
46 }
47
48 return correctCharges(tracklet, snp, tgl);
49}
50
51std::array<float, constants::NCHARGES> PIDBase::correctCharges(const Tracklet64& trklt, float snp, float tgl) const noexcept
52{
53 auto tphi = snp / std::sqrt((1.f - snp) + (1.f + snp));
54 auto trackletLength = std::sqrt(1.f + tphi * tphi + tgl * tgl);
55 const float correction = mLocalGain->getValue(trklt.getHCID() / 2, trklt.getPadCol(), trklt.getPadRow()) * trackletLength;
56 return {
57 trklt.getQ0() / correction,
58 trklt.getQ1() / correction,
59 trklt.getQ2() / correction,
60 };
61}
62
63std::unique_ptr<PIDBase> getTRDPIDPolicy(PIDPolicy policy)
64{
65 LOG(info) << "Creating PID policy. Loading model " << policy;
66 switch (policy) {
67 case PIDPolicy::LQ1D:
68 return std::make_unique<LQ1D>(PIDPolicy::LQ1D);
69 case PIDPolicy::LQ2D:
70 return std::make_unique<LQ2D>(PIDPolicy::LQ2D);
71 case PIDPolicy::LQ3D:
72 return std::make_unique<LQ3D>(PIDPolicy::LQ3D);
73#ifdef TRDPID_WITH_ONNX // Add all policies that use ONNX in this ifdef
74 case PIDPolicy::XGB:
75 return std::make_unique<XGB>(PIDPolicy::XGB);
76 case PIDPolicy::PY:
77 return std::make_unique<PY>(PIDPolicy::PY);
78#endif
80 return std::make_unique<Dummy>(PIDPolicy::Dummy);
81 default:
82 return nullptr;
83 }
84 return nullptr; // cannot be reached
85}
86
87} // namespace trd
88} // namespace o2
This file provides a dummy model, which only outputs -1.f.
This file provides the interface for log-likelihood policies.
This file provides the base for ML policies.
This file provides the base interface for pid policies.
std::array< float, constants::NCHARGES > getCharges(const Tracklet64 &tracklet, const int layer, const TrackTRD &trk, const o2::globaltracking::RecoContainer &input, float snp, float tgl) const noexcept
Definition PIDBase.cxx:30
GLenum GLuint GLint GLint layer
Definition glcorearb.h:1310
std::unique_ptr< PIDBase > getTRDPIDPolicy(PIDPolicy policy)
Factory function to create a PID policy.
Definition PIDBase.cxx:63
PIDPolicy
Option for available PID policies.
Definition PID.h:29
@ LQ3D
3-Dimensional Likelihood model
@ LQ2D
2-Dimensional Likelihood model
@ Dummy
Dummy object outputting -1.f.
@ LQ1D
1-Dimensional Likelihood model
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"